1//===- PresburgerRelation.cpp - MLIR PresburgerRelation Class -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/Presburger/PresburgerRelation.h"
10#include "mlir/Analysis/Presburger/IntegerRelation.h"
11#include "mlir/Analysis/Presburger/MPInt.h"
12#include "mlir/Analysis/Presburger/PWMAFunction.h"
13#include "mlir/Analysis/Presburger/PresburgerSpace.h"
14#include "mlir/Analysis/Presburger/Simplex.h"
15#include "mlir/Analysis/Presburger/Utils.h"
16#include "mlir/Support/LLVM.h"
17#include "mlir/Support/LogicalResult.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/SmallBitVector.h"
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/Support/raw_ostream.h"
22#include <cassert>
23#include <functional>
24#include <optional>
25#include <utility>
26#include <vector>
27
28using namespace mlir;
29using namespace presburger;
30
31PresburgerRelation::PresburgerRelation(const IntegerRelation &disjunct)
32 : space(disjunct.getSpaceWithoutLocals()) {
33 unionInPlace(disjunct);
34}
35
36void PresburgerRelation::setSpace(const PresburgerSpace &oSpace) {
37 assert(space.getNumLocalVars() == 0 && "no locals should be present");
38 space = oSpace;
39 for (IntegerRelation &disjunct : disjuncts)
40 disjunct.setSpaceExceptLocals(space);
41}
42
43void PresburgerRelation::insertVarInPlace(VarKind kind, unsigned pos,
44 unsigned num) {
45 for (IntegerRelation &cs : disjuncts)
46 cs.insertVar(kind, pos, num);
47 space.insertVar(kind, pos, num);
48}
49
50void PresburgerRelation::convertVarKind(VarKind srcKind, unsigned srcPos,
51 unsigned num, VarKind dstKind,
52 unsigned dstPos) {
53 assert(srcKind != VarKind::Local && dstKind != VarKind::Local &&
54 "srcKind/dstKind cannot be local");
55 assert(srcKind != dstKind && "cannot convert variables to the same kind");
56 assert(srcPos + num <= space.getNumVarKind(srcKind) &&
57 "invalid range for source variables");
58 assert(dstPos <= space.getNumVarKind(dstKind) &&
59 "invalid position for destination variables");
60
61 space.convertVarKind(srcKind, srcPos, num, dstKind, dstPos);
62
63 for (IntegerRelation &disjunct : disjuncts)
64 disjunct.convertVarKind(srcKind, varStart: srcPos, varLimit: srcPos + num, dstKind, pos: dstPos);
65}
66
67unsigned PresburgerRelation::getNumDisjuncts() const {
68 return disjuncts.size();
69}
70
71ArrayRef<IntegerRelation> PresburgerRelation::getAllDisjuncts() const {
72 return disjuncts;
73}
74
75const IntegerRelation &PresburgerRelation::getDisjunct(unsigned index) const {
76 assert(index < disjuncts.size() && "index out of bounds!");
77 return disjuncts[index];
78}
79
80/// Mutate this set, turning it into the union of this set and the given
81/// IntegerRelation.
82void PresburgerRelation::unionInPlace(const IntegerRelation &disjunct) {
83 assert(space.isCompatible(disjunct.getSpace()) && "Spaces should match");
84 disjuncts.push_back(Elt: disjunct);
85}
86
87/// Mutate this set, turning it into the union of this set and the given set.
88///
89/// This is accomplished by simply adding all the disjuncts of the given set
90/// to this set.
91void PresburgerRelation::unionInPlace(const PresburgerRelation &set) {
92 assert(space.isCompatible(set.getSpace()) && "Spaces should match");
93
94 if (isObviouslyEqual(set))
95 return;
96
97 if (isObviouslyEmpty()) {
98 disjuncts = set.disjuncts;
99 return;
100 }
101 if (set.isObviouslyEmpty())
102 return;
103
104 if (isObviouslyUniverse())
105 return;
106 if (set.isObviouslyUniverse()) {
107 disjuncts = set.disjuncts;
108 return;
109 }
110
111 for (const IntegerRelation &disjunct : set.disjuncts)
112 unionInPlace(disjunct);
113}
114
115/// Return the union of this set and the given set.
116PresburgerRelation
117PresburgerRelation::unionSet(const PresburgerRelation &set) const {
118 assert(space.isCompatible(set.getSpace()) && "Spaces should match");
119 PresburgerRelation result = *this;
120 result.unionInPlace(set);
121 return result;
122}
123
124/// A point is contained in the union iff any of the parts contain the point.
125bool PresburgerRelation::containsPoint(ArrayRef<MPInt> point) const {
126 return llvm::any_of(Range: disjuncts, P: [&](const IntegerRelation &disjunct) {
127 return (disjunct.containsPointNoLocal(point));
128 });
129}
130
131PresburgerRelation
132PresburgerRelation::getUniverse(const PresburgerSpace &space) {
133 PresburgerRelation result(space);
134 result.unionInPlace(disjunct: IntegerRelation::getUniverse(space));
135 return result;
136}
137
138PresburgerRelation PresburgerRelation::getEmpty(const PresburgerSpace &space) {
139 return PresburgerRelation(space);
140}
141
142// Return the intersection of this set with the given set.
143//
144// We directly compute (S_1 or S_2 ...) and (T_1 or T_2 ...)
145// as (S_1 and T_1) or (S_1 and T_2) or ...
146//
147// If S_i or T_j have local variables, then S_i and T_j contains the local
148// variables of both.
149PresburgerRelation
150PresburgerRelation::intersect(const PresburgerRelation &set) const {
151 assert(space.isCompatible(set.getSpace()) && "Spaces should match");
152
153 // If the set is empty or the other set is universe,
154 // directly return the set
155 if (isObviouslyEmpty() || set.isObviouslyUniverse())
156 return *this;
157
158 if (set.isObviouslyEmpty() || isObviouslyUniverse())
159 return set;
160
161 PresburgerRelation result(getSpace());
162 for (const IntegerRelation &csA : disjuncts) {
163 for (const IntegerRelation &csB : set.disjuncts) {
164 IntegerRelation intersection = csA.intersect(other: csB);
165 if (!intersection.isEmpty())
166 result.unionInPlace(disjunct: intersection);
167 }
168 }
169 return result;
170}
171
172PresburgerRelation
173PresburgerRelation::intersectRange(const PresburgerSet &set) const {
174 assert(space.getRangeSpace().isCompatible(set.getSpace()) &&
175 "Range of `this` must be compatible with range of `set`");
176
177 PresburgerRelation other = set;
178 other.insertVarInPlace(kind: VarKind::Domain, pos: 0, num: getNumDomainVars());
179 return intersect(set: other);
180}
181
182PresburgerRelation
183PresburgerRelation::intersectDomain(const PresburgerSet &set) const {
184 assert(space.getDomainSpace().isCompatible(set.getSpace()) &&
185 "Domain of `this` must be compatible with range of `set`");
186
187 PresburgerRelation other = set;
188 other.insertVarInPlace(kind: VarKind::Domain, pos: 0, num: getNumRangeVars());
189 other.inverse();
190 return intersect(set: other);
191}
192
193PresburgerSet PresburgerRelation::getDomainSet() const {
194 PresburgerSet result = PresburgerSet::getEmpty(space: space.getDomainSpace());
195 for (const IntegerRelation &cs : disjuncts)
196 result.unionInPlace(disjunct: cs.getDomainSet());
197 return result;
198}
199
200PresburgerSet PresburgerRelation::getRangeSet() const {
201 PresburgerSet result = PresburgerSet::getEmpty(space: space.getRangeSpace());
202 for (const IntegerRelation &cs : disjuncts)
203 result.unionInPlace(disjunct: cs.getRangeSet());
204 return result;
205}
206
207void PresburgerRelation::inverse() {
208 for (IntegerRelation &cs : disjuncts)
209 cs.inverse();
210
211 if (getNumDisjuncts())
212 setSpace(getDisjunct(index: 0).getSpaceWithoutLocals());
213}
214
215void PresburgerRelation::compose(const PresburgerRelation &rel) {
216 assert(getSpace().getRangeSpace().isCompatible(
217 rel.getSpace().getDomainSpace()) &&
218 "Range of `this` should be compatible with domain of `rel`");
219
220 PresburgerRelation result =
221 PresburgerRelation::getEmpty(space: PresburgerSpace::getRelationSpace(
222 numDomain: getNumDomainVars(), numRange: rel.getNumRangeVars(), numSymbols: getNumSymbolVars()));
223 for (const IntegerRelation &csA : disjuncts) {
224 for (const IntegerRelation &csB : rel.disjuncts) {
225 IntegerRelation composition = csA;
226 composition.compose(rel: csB);
227 if (!composition.isEmpty())
228 result.unionInPlace(disjunct: composition);
229 }
230 }
231 *this = result;
232}
233
234void PresburgerRelation::applyDomain(const PresburgerRelation &rel) {
235 assert(getSpace().getDomainSpace().isCompatible(
236 rel.getSpace().getDomainSpace()) &&
237 "Domain of `this` should be compatible with domain of `rel`");
238
239 inverse();
240 compose(rel);
241 inverse();
242}
243
244void PresburgerRelation::applyRange(const PresburgerRelation &rel) {
245 compose(rel);
246}
247
248static SymbolicLexOpt findSymbolicIntegerLexOpt(const PresburgerRelation &rel,
249 bool isMin) {
250 SymbolicLexOpt result(rel.getSpace());
251 PWMAFunction &lexopt = result.lexopt;
252 PresburgerSet &unboundedDomain = result.unboundedDomain;
253 for (const IntegerRelation &cs : rel.getAllDisjuncts()) {
254 SymbolicLexOpt s(rel.getSpace());
255 if (isMin) {
256 s = cs.findSymbolicIntegerLexMin();
257 lexopt = lexopt.unionLexMin(func: s.lexopt);
258 } else {
259 s = cs.findSymbolicIntegerLexMax();
260 lexopt = lexopt.unionLexMax(func: s.lexopt);
261 }
262 unboundedDomain = unboundedDomain.intersect(set: s.unboundedDomain);
263 }
264 return result;
265}
266
267SymbolicLexOpt PresburgerRelation::findSymbolicIntegerLexMin() const {
268 return findSymbolicIntegerLexOpt(rel: *this, isMin: true);
269}
270
271SymbolicLexOpt PresburgerRelation::findSymbolicIntegerLexMax() const {
272 return findSymbolicIntegerLexOpt(rel: *this, isMin: false);
273}
274
275/// Return the coefficients of the ineq in `rel` specified by `idx`.
276/// `idx` can refer not only to an actual inequality of `rel`, but also
277/// to either of the inequalities that make up an equality in `rel`.
278///
279/// When 0 <= idx < rel.getNumInequalities(), this returns the coeffs of the
280/// idx-th inequality of `rel`.
281///
282/// Otherwise, it is then considered to index into the ineqs corresponding to
283/// eqs of `rel`, and it must hold that
284///
285/// 0 <= idx - rel.getNumInequalities() < 2*getNumEqualities().
286///
287/// For every eq `coeffs == 0` there are two possible ineqs to index into.
288/// The first is coeffs >= 0 and the second is coeffs <= 0.
289static SmallVector<MPInt, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
290 unsigned idx) {
291 assert(idx < rel.getNumInequalities() + 2 * rel.getNumEqualities() &&
292 "idx out of bounds!");
293 if (idx < rel.getNumInequalities())
294 return llvm::to_vector<8>(Range: rel.getInequality(idx));
295
296 idx -= rel.getNumInequalities();
297 ArrayRef<MPInt> eqCoeffs = rel.getEquality(idx: idx / 2);
298
299 if (idx % 2 == 0)
300 return llvm::to_vector<8>(Range&: eqCoeffs);
301 return getNegatedCoeffs(coeffs: eqCoeffs);
302}
303
304PresburgerRelation PresburgerRelation::computeReprWithOnlyDivLocals() const {
305 if (hasOnlyDivLocals())
306 return *this;
307
308 // The result is just the union of the reprs of the disjuncts.
309 PresburgerRelation result(getSpace());
310 for (const IntegerRelation &disjunct : disjuncts)
311 result.unionInPlace(set: disjunct.computeReprWithOnlyDivLocals());
312 return result;
313}
314
315/// Return the set difference b \ s.
316///
317/// In the following, U denotes union, /\ denotes intersection, \ denotes set
318/// difference and ~ denotes complement.
319///
320/// Let s = (U_i s_i). We want b \ (U_i s_i).
321///
322/// Let s_i = /\_j s_ij, where each s_ij is a single inequality. To compute
323/// b \ s_i = b /\ ~s_i, we partition s_i based on the first violated
324/// inequality: ~s_i = (~s_i1) U (s_i1 /\ ~s_i2) U (s_i1 /\ s_i2 /\ ~s_i3) U ...
325/// And the required result is (b /\ ~s_i1) U (b /\ s_i1 /\ ~s_i2) U ...
326/// We recurse by subtracting U_{j > i} S_j from each of these parts and
327/// returning the union of the results. Each equality is handled as a
328/// conjunction of two inequalities.
329///
330/// Note that the same approach works even if an inequality involves a floor
331/// division. For example, the complement of x <= 7*floor(x/7) is still
332/// x > 7*floor(x/7). Since b \ s_i contains the inequalities of both b and s_i
333/// (or the complements of those inequalities), b \ s_i may contain the
334/// divisions present in both b and s_i. Therefore, we need to add the local
335/// division variables of both b and s_i to each part in the result. This means
336/// adding the local variables of both b and s_i, as well as the corresponding
337/// division inequalities to each part. Since the division inequalities are
338/// added to each part, we can skip the parts where the complement of any
339/// division inequality is added, as these parts will become empty anyway.
340///
341/// As a heuristic, we try adding all the constraints and check if simplex
342/// says that the intersection is empty. If it is, then subtracting this
343/// disjuncts is a no-op and we just skip it. Also, in the process we find out
344/// that some constraints are redundant. These redundant constraints are
345/// ignored.
346///
347static PresburgerRelation getSetDifference(IntegerRelation b,
348 const PresburgerRelation &s) {
349 assert(b.getSpace().isCompatible(s.getSpace()) && "Spaces should match");
350 if (b.isEmptyByGCDTest())
351 return PresburgerRelation::getEmpty(space: b.getSpaceWithoutLocals());
352
353 if (!s.hasOnlyDivLocals())
354 return getSetDifference(b, s: s.computeReprWithOnlyDivLocals());
355
356 // Remove duplicate divs up front here to avoid existing
357 // divs disappearing in the call to mergeLocalVars below.
358 b.removeDuplicateDivs();
359
360 PresburgerRelation result =
361 PresburgerRelation::getEmpty(space: b.getSpaceWithoutLocals());
362 Simplex simplex(b);
363
364 // This algorithm is more naturally expressed recursively, but we implement
365 // it iteratively here to avoid issues with stack sizes.
366 //
367 // Each level of the recursion has five stack variables.
368 struct Frame {
369 // A snapshot of the simplex state to rollback to.
370 unsigned simplexSnapshot;
371 // A CountsSnapshot of `b` to rollback to.
372 IntegerRelation::CountsSnapshot bCounts;
373 // The IntegerRelation currently being operated on.
374 IntegerRelation sI;
375 // A list of indexes (see getIneqCoeffsFromIdx) of inequalities to be
376 // processed.
377 SmallVector<unsigned, 8> ineqsToProcess;
378 // The index of the last inequality that was processed at this level.
379 // This is empty when we are coming to this level for the first time.
380 std::optional<unsigned> lastIneqProcessed;
381 };
382 SmallVector<Frame, 2> frames;
383
384 // When we "recurse", we ensure the current frame is stored in `frames` and
385 // increment `level`. When we return, we decrement `level`.
386 unsigned level = 1;
387 while (level > 0) {
388 if (level - 1 >= s.getNumDisjuncts()) {
389 // No more parts to subtract; add to the result and return.
390 result.unionInPlace(disjunct: b);
391 level = frames.size();
392 continue;
393 }
394
395 if (level > frames.size()) {
396 // No frame for this level yet, so we have just recursed into this level.
397 IntegerRelation sI = s.getDisjunct(index: level - 1);
398 // Remove the duplicate divs up front to avoid them possibly disappearing
399 // in the call to mergeLocalVars below.
400 sI.removeDuplicateDivs();
401
402 // Below, we append some additional constraints and ids to b. We want to
403 // rollback b to its initial state before returning, which we will do by
404 // removing all constraints beyond the original number of inequalities
405 // and equalities, so we store these counts first.
406 IntegerRelation::CountsSnapshot initBCounts = b.getCounts();
407 // Similarly, we also want to rollback simplex to its original state.
408 unsigned initialSnapshot = simplex.getSnapshot();
409
410 // Add sI's locals to b, after b's locals. Only those locals of sI which
411 // do not already exist in b will be added. (i.e., duplicate divisions
412 // will not be added.) Also add b's locals to sI, in such a way that both
413 // have the same locals in the same order in the end.
414 b.mergeLocalVars(other&: sI);
415
416 // Find out which inequalities of sI correspond to division inequalities
417 // for the local variables of sI.
418 //
419 // Careful! This has to be done after the merge above; otherwise, the
420 // dividends won't contain the new ids inserted during the merge.
421 std::vector<MaybeLocalRepr> repr(sI.getNumLocalVars());
422 DivisionRepr divs = sI.getLocalReprs(repr: &repr);
423
424 // Mark which inequalities of sI are division inequalities and add all
425 // such inequalities to b.
426 llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() +
427 2 * sI.getNumEqualities());
428 for (unsigned i = initBCounts.getSpace().getNumLocalVars(),
429 e = sI.getNumLocalVars();
430 i < e; ++i) {
431 assert(
432 repr[i] &&
433 "Subtraction is not supported when a representation of the local "
434 "variables of the subtrahend cannot be found!");
435
436 if (repr[i].kind == ReprKind::Inequality) {
437 unsigned lb = repr[i].repr.inequalityPair.lowerBoundIdx;
438 unsigned ub = repr[i].repr.inequalityPair.upperBoundIdx;
439
440 b.addInequality(inEq: sI.getInequality(idx: lb));
441 b.addInequality(inEq: sI.getInequality(idx: ub));
442
443 assert(lb != ub &&
444 "Upper and lower bounds must be different inequalities!");
445 canIgnoreIneq[lb] = true;
446 canIgnoreIneq[ub] = true;
447 } else {
448 assert(repr[i].kind == ReprKind::Equality &&
449 "ReprKind isn't inequality so should be equality");
450
451 // Consider the case (x) : (x = 3e + 1), where e is a local.
452 // Its complement is (x) : (x = 3e) or (x = 3e + 2).
453 //
454 // This can be computed by considering the set to be
455 // (x) : (x = 3*(x floordiv 3) + 1).
456 //
457 // Now there are no equalities defining divisions; the division is
458 // defined by the standard division equalities for e = x floordiv 3,
459 // i.e., 0 <= x - 3*e <= 2.
460 // So now as before, we add these division inequalities to b. The
461 // equality is now just an ordinary constraint that must be considered
462 // in the remainder of the algorithm. The division inequalities must
463 // need not be considered, same as above, and they automatically will
464 // not be because they were never a part of sI; we just infer them
465 // from the equality and add them only to b.
466 b.addInequality(
467 inEq: getDivLowerBound(dividend: divs.getDividend(i), divisor: divs.getDenom(i),
468 localVarIdx: sI.getVarKindOffset(kind: VarKind::Local) + i));
469 b.addInequality(
470 inEq: getDivUpperBound(dividend: divs.getDividend(i), divisor: divs.getDenom(i),
471 localVarIdx: sI.getVarKindOffset(kind: VarKind::Local) + i));
472 }
473 }
474
475 unsigned offset = simplex.getNumConstraints();
476 unsigned numLocalsAdded =
477 b.getNumLocalVars() - initBCounts.getSpace().getNumLocalVars();
478 simplex.appendVariable(count: numLocalsAdded);
479
480 unsigned snapshotBeforeIntersect = simplex.getSnapshot();
481 simplex.intersectIntegerRelation(rel: sI);
482
483 if (simplex.isEmpty()) {
484 // b /\ s_i is empty, so b \ s_i = b. We move directly to i + 1.
485 // We are ignoring level i completely, so we restore the state
486 // *before* going to the next level.
487 b.truncate(counts: initBCounts);
488 simplex.rollback(snapshot: initialSnapshot);
489 // Recurse. We haven't processed any inequalities and
490 // we don't need to process anything when we return.
491 //
492 // TODO: consider supporting tail recursion directly if this becomes
493 // relevant for performance.
494 frames.push_back(Elt: Frame{.simplexSnapshot: initialSnapshot, .bCounts: initBCounts, .sI: sI,
495 /*ineqsToProcess=*/{},
496 /*lastIneqProcessed=*/{}});
497 ++level;
498 continue;
499 }
500
501 // Equalities are added to simplex as a pair of inequalities.
502 unsigned totalNewSimplexInequalities =
503 2 * sI.getNumEqualities() + sI.getNumInequalities();
504 // Look for redundant constraints among the constraints of sI. We don't
505 // care about redundant constraints in `b` at this point.
506 //
507 // When there are two copies of a constraint in `simplex`, i.e., among the
508 // constraints of `b` and `sI`, only one of them can be marked redundant.
509 // (Assuming no other constraint makes these redundant.)
510 //
511 // In a case where there is one copy in `b` and one in `sI`, we want the
512 // one in `sI` to be marked, not the one in `b`. Therefore, it's not
513 // enough to ignore the constraints of `b` when checking which
514 // constraints `detectRedundant` has marked redundant; we explicitly tell
515 // `detectRedundant` to only mark constraints from `sI` as being
516 // redundant.
517 simplex.detectRedundant(offset, count: totalNewSimplexInequalities);
518 for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
519 canIgnoreIneq[j] = simplex.isMarkedRedundant(constraintIndex: offset + j);
520 simplex.rollback(snapshot: snapshotBeforeIntersect);
521
522 SmallVector<unsigned, 8> ineqsToProcess;
523 ineqsToProcess.reserve(N: totalNewSimplexInequalities);
524 for (unsigned i = 0; i < totalNewSimplexInequalities; ++i)
525 if (!canIgnoreIneq[i])
526 ineqsToProcess.push_back(Elt: i);
527
528 if (ineqsToProcess.empty()) {
529 // Nothing to process; return. (we have no frame to pop.)
530 level = frames.size();
531 continue;
532 }
533
534 unsigned simplexSnapshot = simplex.getSnapshot();
535 IntegerRelation::CountsSnapshot bCounts = b.getCounts();
536 frames.push_back(Elt: Frame{.simplexSnapshot: simplexSnapshot, .bCounts: bCounts, .sI: sI, .ineqsToProcess: ineqsToProcess,
537 /*lastIneqProcessed=*/std::nullopt});
538 // We have completed the initial setup for this level.
539 // Fallthrough to the main recursive part below.
540 }
541
542 // For each inequality ineq, we first recurse with the part where ineq
543 // is not satisfied, and then add ineq to b and simplex because
544 // ineq must be satisfied by all later parts.
545 if (level == frames.size()) {
546 Frame &frame = frames.back();
547 if (frame.lastIneqProcessed) {
548 // Let the current value of b be b' and
549 // let the initial value of b when we first came to this level be b.
550 //
551 // b' is equal to b /\ s_i1 /\ s_i2 /\ ... /\ s_i{j-1} /\ ~s_ij.
552 // We had previously recursed with the part where s_ij was not
553 // satisfied; all further parts satisfy s_ij, so we rollback to the
554 // state before adding this complement constraint, and add s_ij to b.
555 simplex.rollback(snapshot: frame.simplexSnapshot);
556 b.truncate(counts: frame.bCounts);
557 SmallVector<MPInt, 8> ineq =
558 getIneqCoeffsFromIdx(rel: frame.sI, idx: *frame.lastIneqProcessed);
559 b.addInequality(inEq: ineq);
560 simplex.addInequality(coeffs: ineq);
561 }
562
563 if (frame.ineqsToProcess.empty()) {
564 // No ineqs left to process; pop this level's frame and return.
565 frames.pop_back();
566 level = frames.size();
567 continue;
568 }
569
570 // "Recurse" with the part where the ineq is not satisfied.
571 frame.bCounts = b.getCounts();
572 frame.simplexSnapshot = simplex.getSnapshot();
573
574 unsigned idx = frame.ineqsToProcess.back();
575 SmallVector<MPInt, 8> ineq =
576 getComplementIneq(ineq: getIneqCoeffsFromIdx(rel: frame.sI, idx));
577 b.addInequality(inEq: ineq);
578 simplex.addInequality(coeffs: ineq);
579
580 frame.ineqsToProcess.pop_back();
581 frame.lastIneqProcessed = idx;
582 ++level;
583 continue;
584 }
585 }
586
587 // Try to simplify the results.
588 result = result.simplify();
589
590 return result;
591}
592
593/// Return the complement of this set.
594PresburgerRelation PresburgerRelation::complement() const {
595 return getSetDifference(b: IntegerRelation::getUniverse(space: getSpace()), s: *this);
596}
597
598/// Return the result of subtract the given set from this set, i.e.,
599/// return `this \ set`.
600PresburgerRelation
601PresburgerRelation::subtract(const PresburgerRelation &set) const {
602 assert(space.isCompatible(set.getSpace()) && "Spaces should match");
603 PresburgerRelation result(getSpace());
604
605 // If we know that the two sets are clearly equal, we can simply return the
606 // empty set.
607 if (isObviouslyEqual(set))
608 return result;
609
610 // We compute (U_i t_i) \ (U_i set_i) as U_i (t_i \ V_i set_i).
611 for (const IntegerRelation &disjunct : disjuncts)
612 result.unionInPlace(set: getSetDifference(b: disjunct, s: set));
613 return result;
614}
615
616/// T is a subset of S iff T \ S is empty, since if T \ S contains a
617/// point then this is a point that is contained in T but not S, and
618/// if T contains a point that is not in S, this also lies in T \ S.
619bool PresburgerRelation::isSubsetOf(const PresburgerRelation &set) const {
620 return this->subtract(set).isIntegerEmpty();
621}
622
623/// Two sets are equal iff they are subsets of each other.
624bool PresburgerRelation::isEqual(const PresburgerRelation &set) const {
625 assert(space.isCompatible(set.getSpace()) && "Spaces should match");
626 return this->isSubsetOf(set) && set.isSubsetOf(set: *this);
627}
628
629bool PresburgerRelation::isObviouslyEqual(const PresburgerRelation &set) const {
630 if (!space.isCompatible(other: set.getSpace()))
631 return false;
632
633 if (getNumDisjuncts() != set.getNumDisjuncts())
634 return false;
635
636 // Compare each disjunct in this PresburgerRelation with the corresponding
637 // disjunct in the other PresburgerRelation.
638 for (unsigned int i = 0, n = getNumDisjuncts(); i < n; ++i) {
639 if (!getDisjunct(index: i).isObviouslyEqual(other: set.getDisjunct(index: i)))
640 return false;
641 }
642 return true;
643}
644
645/// Return true if the Presburger relation represents the universe set, false
646/// otherwise. It is a simple check that only check if the relation has at least
647/// one unconstrained disjunct, indicating the absence of constraints or
648/// conditions.
649bool PresburgerRelation::isObviouslyUniverse() const {
650 for (const IntegerRelation &disjunct : getAllDisjuncts()) {
651 if (disjunct.getNumConstraints() == 0)
652 return true;
653 }
654 return false;
655}
656
657bool PresburgerRelation::isConvexNoLocals() const {
658 return getNumDisjuncts() == 1 && getSpace().getNumLocalVars() == 0;
659}
660
661/// Return true if there is no disjunct, false otherwise.
662bool PresburgerRelation::isObviouslyEmpty() const {
663 return getNumDisjuncts() == 0;
664}
665
666/// Return true if all the sets in the union are known to be integer empty,
667/// false otherwise.
668bool PresburgerRelation::isIntegerEmpty() const {
669 // The set is empty iff all of the disjuncts are empty.
670 return llvm::all_of(Range: disjuncts, P: std::mem_fn(pm: &IntegerRelation::isIntegerEmpty));
671}
672
673bool PresburgerRelation::findIntegerSample(SmallVectorImpl<MPInt> &sample) {
674 // A sample exists iff any of the disjuncts contains a sample.
675 for (const IntegerRelation &disjunct : disjuncts) {
676 if (std::optional<SmallVector<MPInt, 8>> opt =
677 disjunct.findIntegerSample()) {
678 sample = std::move(*opt);
679 return true;
680 }
681 }
682 return false;
683}
684
685std::optional<MPInt> PresburgerRelation::computeVolume() const {
686 assert(getNumSymbolVars() == 0 && "Symbols are not yet supported!");
687 // The sum of the volumes of the disjuncts is a valid overapproximation of the
688 // volume of their union, even if they overlap.
689 MPInt result(0);
690 for (const IntegerRelation &disjunct : disjuncts) {
691 std::optional<MPInt> volume = disjunct.computeVolume();
692 if (!volume)
693 return {};
694 result += *volume;
695 }
696 return result;
697}
698
/// The SetCoalescer class contains all functionality concerning the coalesce
/// heuristic. It is built from a `PresburgerRelation` and has the `coalesce()`
/// function as its main API. The coalesce heuristic simplifies the
/// representation of a PresburgerRelation. In particular, it removes all
/// disjuncts which are subsets of other disjuncts in the union, and it combines
/// sets that overlap and can be combined in a convex way.
///
/// The per-pair scratch state (`negEqs`, `redundantIneqs*`, `cuttingIneqs*`)
/// is cleared by `coalesce()` before every call to `coalescePair()`.
class presburger::SetCoalescer {

public:
  /// Simplifies the representation of the relation this coalescer was built
  /// from and returns the simplified relation.
  PresburgerRelation coalesce();

  /// Construct a SetCoalescer from a PresburgerRelation. Empty disjuncts are
  /// dropped and redundant constraints removed from the remaining ones.
  SetCoalescer(const PresburgerRelation &s);

private:
  /// The space of the set the SetCoalescer is coalescing.
  PresburgerSpace space;

  /// The current list of `IntegerRelation`s that the currently coalesced set is
  /// the union of.
  SmallVector<IntegerRelation, 2> disjuncts;
  /// The list of `Simplex`s constructed from the elements of `disjuncts`; the
  /// constructor keeps it index-aligned with `disjuncts`.
  SmallVector<Simplex, 2> simplices;

  /// Storage for the negated equalities produced while typing constraints.
  /// Keeping them in a member ensures these constraints still exist after the
  /// typing function has returned (presumably because the ArrayRef lists below
  /// may refer to them — confirm in typeEquality).
  SmallVector<SmallVector<MPInt, 2>, 2> negEqs;

  /// `redundantIneqsA` is the inequalities of `a` that are redundant for `b`
  /// (similarly for `cuttingIneqsA`, `redundantIneqsB`, and `cuttingIneqsB`).
  SmallVector<ArrayRef<MPInt>, 2> redundantIneqsA;
  SmallVector<ArrayRef<MPInt>, 2> cuttingIneqsA;

  SmallVector<ArrayRef<MPInt>, 2> redundantIneqsB;
  SmallVector<ArrayRef<MPInt>, 2> cuttingIneqsB;

  /// Given a Simplex `simp` and one of its inequalities `ineq`, check that
  /// every inequality in `cuttingIneqsB` is redundant on the facet of `simp`
  /// where `ineq` holds as an equality. `simp` is restored afterwards.
  bool isFacetContained(ArrayRef<MPInt> ineq, Simplex &simp);

  /// Removes redundant constraints from `disjunct`, adds it to `disjuncts` and
  /// removes the disjuncts at position `i` and `j`. Updates `simplices` to
  /// reflect the changes. `i` and `j` cannot be equal.
  void addCoalescedDisjunct(unsigned i, unsigned j,
                            const IntegerRelation &disjunct);

  /// Checks whether `a` and `b` can be combined in a convex sense, if there
  /// exist cutting inequalities.
  ///
  /// An example of this case:
  ///    ___________        ___________
  ///   /   /  |   /       /          /
  ///   \   \  |  /   ==>  \         /
  ///    \   \ | /          \       /
  ///     \___\|/            \_____/
  ///
  ///
  LogicalResult coalescePairCutCase(unsigned i, unsigned j);

  /// Types the inequality `ineq` according to its `IneqType` for `simp` into
  /// `redundantIneqsB` and `cuttingIneqsB`. Returns success if no separate
  /// inequalities were encountered; otherwise, returns failure.
  LogicalResult typeInequality(ArrayRef<MPInt> ineq, Simplex &simp);

  /// Types the equality `eq`, i.e. for `eq` == 0, types both `eq` >= 0 and
  /// -`eq` >= 0 according to their `IneqType` for `simp` into
  /// `redundantIneqsB` and `cuttingIneqsB`. Returns success if no separate
  /// inequalities were encountered; otherwise, returns failure.
  LogicalResult typeEquality(ArrayRef<MPInt> eq, Simplex &simp);

  /// Replaces the element at position `i` with the last element and erases
  /// the last element for both `disjuncts` and `simplices`.
  void eraseDisjunct(unsigned i);

  /// Attempts to coalesce the two IntegerRelations at position `i` and `j`
  /// in `disjuncts` in-place. Returns whether the disjuncts were
  /// successfully coalesced. The simplices in `simplices` need to be the ones
  /// constructed from `disjuncts`. At this point, there are no empty
  /// disjuncts in `disjuncts` left.
  LogicalResult coalescePair(unsigned i, unsigned j);
};
782
783/// Constructs a `SetCoalescer` from a `PresburgerRelation`. Only adds non-empty
784/// `IntegerRelation`s to the `disjuncts` vector.
785SetCoalescer::SetCoalescer(const PresburgerRelation &s) : space(s.getSpace()) {
786
787 disjuncts = s.disjuncts;
788
789 simplices.reserve(N: s.getNumDisjuncts());
790 // Note that disjuncts.size() changes during the loop.
791 for (unsigned i = 0; i < disjuncts.size();) {
792 disjuncts[i].removeRedundantConstraints();
793 Simplex simp(disjuncts[i]);
794 if (simp.isEmpty()) {
795 disjuncts[i] = disjuncts[disjuncts.size() - 1];
796 disjuncts.pop_back();
797 continue;
798 }
799 ++i;
800 simplices.push_back(Elt: simp);
801 }
802}
803
804/// Simplifies the representation of a PresburgerSet.
805PresburgerRelation SetCoalescer::coalesce() {
806 // For all tuples of IntegerRelations, check whether they can be
807 // coalesced. When coalescing is successful, the contained IntegerRelation
808 // is swapped with the last element of `disjuncts` and subsequently erased
809 // and similarly for simplices.
810 for (unsigned i = 0; i < disjuncts.size();) {
811
812 // TODO: This does some comparisons two times (index 0 with 1 and index 1
813 // with 0).
814 bool broken = false;
815 for (unsigned j = 0, e = disjuncts.size(); j < e; ++j) {
816 negEqs.clear();
817 redundantIneqsA.clear();
818 redundantIneqsB.clear();
819 cuttingIneqsA.clear();
820 cuttingIneqsB.clear();
821 if (i == j)
822 continue;
823 if (coalescePair(i, j).succeeded()) {
824 broken = true;
825 break;
826 }
827 }
828
829 // Only if the inner loop was not broken, i is incremented. This is
830 // required as otherwise, if a coalescing occurs, the IntegerRelation
831 // now at position i is not compared.
832 if (!broken)
833 ++i;
834 }
835
836 PresburgerRelation newSet = PresburgerRelation::getEmpty(space);
837 for (const IntegerRelation &disjunct : disjuncts)
838 newSet.unionInPlace(disjunct);
839
840 return newSet;
841}
842
843/// Given a Simplex `simp` and one of its inequalities `ineq`, check
844/// that all inequalities of `cuttingIneqsB` are redundant for the facet of
845/// `simp` where `ineq` holds as an equality is contained within `a`.
846bool SetCoalescer::isFacetContained(ArrayRef<MPInt> ineq, Simplex &simp) {
847 SimplexRollbackScopeExit scopeExit(simp);
848 simp.addEquality(coeffs: ineq);
849 return llvm::all_of(Range&: cuttingIneqsB, P: [&simp](ArrayRef<MPInt> curr) {
850 return simp.isRedundantInequality(coeffs: curr);
851 });
852}
853
/// Replaces the disjuncts at positions `i` and `j`, and their simplices, by
/// the single relation `disjunct`.
///
/// Both vectors shrink by exactly one element:
/// - the two vacated slots are refilled from the tail of the vectors;
/// - the new relation is stored in the last remaining slot (`n` - 2), after
///   its redundant constraints are removed;
/// - a fresh Simplex is built for that slot from the reduced relation.
///
/// The order of the remaining disjuncts is not preserved.
void SetCoalescer::addCoalescedDisjunct(unsigned i, unsigned j,
                                        const IntegerRelation &disjunct) {
  assert(i != j && "The indices must refer to different disjuncts");
  // `i` != `j` implies at least two disjuncts, so `n` - 2 is a valid index.
  unsigned n = disjuncts.size();
  if (j == n - 1) {
    // This case needs special handling since position `n` - 1 is removed
    // from the vector, hence the `IntegerRelation` at position `n` - 2 is
    // lost otherwise. `j` simply disappears with the pop_back; the element
    // at `n` - 2 is rescued into slot `i` (a self-assignment if `i` is
    // `n` - 2), and slot `n` - 2 then receives the new relation.
    disjuncts[i] = disjuncts[n - 2];
    disjuncts.pop_back();
    disjuncts[n - 2] = disjunct;
    disjuncts[n - 2].removeRedundantConstraints();

    // Mirror the same moves on `simplices`. The new Simplex is built from
    // the already-reduced relation.
    simplices[i] = simplices[n - 2];
    simplices.pop_back();
    simplices[n - 2] = Simplex(disjuncts[n - 2]);

  } else {
    // Other possible edge cases are correct since for `j` or `i` == `n` -
    // 2, the `IntegerRelation` at position `n` - 2 should be lost anyway. The
    // case `i` == `n` - 1 makes the first following statement a noop.
    // Hence, in this case the same thing is done as above, but with `j`
    // rather than `i`.
    // When `i` == `n` - 2, slot `i` first receives element `n` - 1, and the
    // second statement then copies it on into slot `j`. This is intended.
    disjuncts[i] = disjuncts[n - 1];
    disjuncts[j] = disjuncts[n - 2];
    disjuncts.pop_back();
    disjuncts[n - 2] = disjunct;
    disjuncts[n - 2].removeRedundantConstraints();

    // Mirror the same moves on `simplices`.
    simplices[i] = simplices[n - 1];
    simplices[j] = simplices[n - 2];
    simplices.pop_back();
    simplices[n - 2] = Simplex(disjuncts[n - 2]);
  }
}
889
890/// Given two polyhedra `a` and `b` at positions `i` and `j` in
891/// `disjuncts` and `redundantIneqsA` being the inequalities of `a` that
892/// are redundant for `b` (similarly for `cuttingIneqsA`, `redundantIneqsB`,
893/// and `cuttingIneqsB`), Checks whether the facets of all cutting
894/// inequalites of `a` are contained in `b`. If so, a new polyhedron
895/// consisting of all redundant inequalites of `a` and `b` and all
896/// equalities of both is created.
897///
898/// An example of this case:
899/// ___________ ___________
900/// / / | / / /
901/// \ \ | / ==> \ /
902/// \ \ | / \ /
903/// \___\|/ \_____/
904///
905///
906LogicalResult SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
907 /// All inequalities of `b` need to be redundant. We already know that the
908 /// redundant ones are, so only the cutting ones remain to be checked.
909 Simplex &simp = simplices[i];
910 IntegerRelation &disjunct = disjuncts[i];
911 if (llvm::any_of(Range&: cuttingIneqsA, P: [this, &simp](ArrayRef<MPInt> curr) {
912 return !isFacetContained(ineq: curr, simp);
913 }))
914 return failure();
915 IntegerRelation newSet(disjunct.getSpace());
916
917 for (ArrayRef<MPInt> curr : redundantIneqsA)
918 newSet.addInequality(inEq: curr);
919
920 for (ArrayRef<MPInt> curr : redundantIneqsB)
921 newSet.addInequality(inEq: curr);
922
923 addCoalescedDisjunct(i, j, disjunct: newSet);
924 return success();
925}
926
927LogicalResult SetCoalescer::typeInequality(ArrayRef<MPInt> ineq,
928 Simplex &simp) {
929 Simplex::IneqType type = simp.findIneqType(coeffs: ineq);
930 if (type == Simplex::IneqType::Redundant)
931 redundantIneqsB.push_back(Elt: ineq);
932 else if (type == Simplex::IneqType::Cut)
933 cuttingIneqsB.push_back(Elt: ineq);
934 else
935 return failure();
936 return success();
937}
938
939LogicalResult SetCoalescer::typeEquality(ArrayRef<MPInt> eq, Simplex &simp) {
940 if (typeInequality(ineq: eq, simp).failed())
941 return failure();
942 negEqs.push_back(Elt: getNegatedCoeffs(coeffs: eq));
943 ArrayRef<MPInt> inv(negEqs.back());
944 if (typeInequality(ineq: inv, simp).failed())
945 return failure();
946 return success();
947}
948
949void SetCoalescer::eraseDisjunct(unsigned i) {
950 assert(simplices.size() == disjuncts.size() &&
951 "simplices and disjuncts must be equally as long");
952 disjuncts[i] = disjuncts.back();
953 disjuncts.pop_back();
954 simplices[i] = simplices.back();
955 simplices.pop_back();
956}
957
958LogicalResult SetCoalescer::coalescePair(unsigned i, unsigned j) {
959
960 IntegerRelation &a = disjuncts[i];
961 IntegerRelation &b = disjuncts[j];
962 /// Handling of local ids is not yet implemented, so these cases are
963 /// skipped.
964 /// TODO: implement local id support.
965 if (a.getNumLocalVars() != 0 || b.getNumLocalVars() != 0)
966 return failure();
967 Simplex &simpA = simplices[i];
968 Simplex &simpB = simplices[j];
969
970 // Organize all inequalities and equalities of `a` according to their type
971 // for `b` into `redundantIneqsA` and `cuttingIneqsA` (and vice versa for
972 // all inequalities of `b` according to their type in `a`). If a separate
973 // inequality is encountered during typing, the two IntegerRelations
974 // cannot be coalesced.
975 for (int k = 0, e = a.getNumInequalities(); k < e; ++k)
976 if (typeInequality(ineq: a.getInequality(idx: k), simp&: simpB).failed())
977 return failure();
978
979 for (int k = 0, e = a.getNumEqualities(); k < e; ++k)
980 if (typeEquality(eq: a.getEquality(idx: k), simp&: simpB).failed())
981 return failure();
982
983 std::swap(LHS&: redundantIneqsA, RHS&: redundantIneqsB);
984 std::swap(LHS&: cuttingIneqsA, RHS&: cuttingIneqsB);
985
986 for (int k = 0, e = b.getNumInequalities(); k < e; ++k)
987 if (typeInequality(ineq: b.getInequality(idx: k), simp&: simpA).failed())
988 return failure();
989
990 for (int k = 0, e = b.getNumEqualities(); k < e; ++k)
991 if (typeEquality(eq: b.getEquality(idx: k), simp&: simpA).failed())
992 return failure();
993
994 // If there are no cutting inequalities of `a`, `b` is contained
995 // within `a`.
996 if (cuttingIneqsA.empty()) {
997 eraseDisjunct(i: j);
998 return success();
999 }
1000
1001 // Try to apply the cut case
1002 if (coalescePairCutCase(i, j).succeeded())
1003 return success();
1004
1005 // Swap the vectors to compare the pair (j,i) instead of (i,j).
1006 std::swap(LHS&: redundantIneqsA, RHS&: redundantIneqsB);
1007 std::swap(LHS&: cuttingIneqsA, RHS&: cuttingIneqsB);
1008
1009 // If there are no cutting inequalities of `a`, `b` is contained
1010 // within `a`.
1011 if (cuttingIneqsA.empty()) {
1012 eraseDisjunct(i);
1013 return success();
1014 }
1015
1016 // Try to apply the cut case
1017 if (coalescePairCutCase(i: j, j: i).succeeded())
1018 return success();
1019
1020 return failure();
1021}
1022
1023PresburgerRelation PresburgerRelation::coalesce() const {
1024 return SetCoalescer(*this).coalesce();
1025}
1026
1027bool PresburgerRelation::hasOnlyDivLocals() const {
1028 return llvm::all_of(Range: disjuncts, P: [](const IntegerRelation &rel) {
1029 return rel.hasOnlyDivLocals();
1030 });
1031}
1032
1033PresburgerRelation PresburgerRelation::simplify() const {
1034 PresburgerRelation origin = *this;
1035 PresburgerRelation result = PresburgerRelation(getSpace());
1036 for (IntegerRelation &disjunct : origin.disjuncts) {
1037 disjunct.simplify();
1038 if (!disjunct.isObviouslyEmpty())
1039 result.unionInPlace(disjunct);
1040 }
1041 return result;
1042}
1043
1044bool PresburgerRelation::isFullDim() const {
1045 return llvm::any_of(Range: getAllDisjuncts(), P: [&](IntegerRelation disjunct) {
1046 return disjunct.isFullDim();
1047 });
1048}
1049
1050void PresburgerRelation::print(raw_ostream &os) const {
1051 os << "Number of Disjuncts: " << getNumDisjuncts() << "\n";
1052 for (const IntegerRelation &disjunct : disjuncts) {
1053 disjunct.print(os);
1054 os << '\n';
1055 }
1056}
1057
1058void PresburgerRelation::dump() const { print(os&: llvm::errs()); }
1059
1060PresburgerSet PresburgerSet::getUniverse(const PresburgerSpace &space) {
1061 PresburgerSet result(space);
1062 result.unionInPlace(disjunct: IntegerPolyhedron::getUniverse(space));
1063 return result;
1064}
1065
1066PresburgerSet PresburgerSet::getEmpty(const PresburgerSpace &space) {
1067 return PresburgerSet(space);
1068}
1069
/// Constructs a set from the single polyhedron `disjunct` by delegating to the
/// `PresburgerRelation(const IntegerRelation &)` constructor. That constructor
/// takes the space of `disjunct` without locals and unions `disjunct` in.
PresburgerSet::PresburgerSet(const IntegerPolyhedron &disjunct)
    : PresburgerRelation(disjunct) {}

/// Wraps `set` as a PresburgerSet by copy-constructing the PresburgerRelation
/// base from it.
PresburgerSet::PresburgerSet(const PresburgerRelation &set)
    : PresburgerRelation(set) {}
1075
1076PresburgerSet PresburgerSet::unionSet(const PresburgerRelation &set) const {
1077 return PresburgerSet(PresburgerRelation::unionSet(set));
1078}
1079
1080PresburgerSet PresburgerSet::intersect(const PresburgerRelation &set) const {
1081 return PresburgerSet(PresburgerRelation::intersect(set));
1082}
1083
1084PresburgerSet PresburgerSet::complement() const {
1085 return PresburgerSet(PresburgerRelation::complement());
1086}
1087
1088PresburgerSet PresburgerSet::subtract(const PresburgerRelation &set) const {
1089 return PresburgerSet(PresburgerRelation::subtract(set));
1090}
1091
1092PresburgerSet PresburgerSet::coalesce() const {
1093 return PresburgerSet(PresburgerRelation::coalesce());
1094}
1095

source code of mlir/lib/Analysis/Presburger/PresburgerRelation.cpp