1//===- FlatLinearValueConstraints.cpp - Linear Constraint -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis//FlatLinearValueConstraints.h"
10
11#include "mlir/Analysis/Presburger/LinearTransform.h"
12#include "mlir/Analysis/Presburger/PresburgerSpace.h"
13#include "mlir/Analysis/Presburger/Simplex.h"
14#include "mlir/Analysis/Presburger/Utils.h"
15#include "mlir/IR/AffineExprVisitor.h"
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/IntegerSet.h"
18#include "mlir/Support/LLVM.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/SmallPtrSet.h"
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/InterleavedRange.h"
24#include "llvm/Support/raw_ostream.h"
25#include <optional>
26
27#define DEBUG_TYPE "flat-value-constraints"
28
29using namespace mlir;
30using namespace presburger;
31
32//===----------------------------------------------------------------------===//
33// AffineExprFlattener
34//===----------------------------------------------------------------------===//
35
36namespace {
37
38// See comments for SimpleAffineExprFlattener.
39// An AffineExprFlattenerWithLocalVars extends a SimpleAffineExprFlattener by
40// recording constraint information associated with mod's, floordiv's, and
41// ceildiv's in FlatLinearConstraints 'localVarCst'.
42struct AffineExprFlattener : public SimpleAffineExprFlattener {
43 using SimpleAffineExprFlattener::SimpleAffineExprFlattener;
44
45 // Constraints connecting newly introduced local variables (for mod's and
46 // div's) to existing (dimensional and symbolic) ones. These are always
47 // inequalities.
48 IntegerPolyhedron localVarCst;
49
50 AffineExprFlattener(unsigned nDims, unsigned nSymbols)
51 : SimpleAffineExprFlattener(nDims, nSymbols),
52 localVarCst(PresburgerSpace::getSetSpace(numDims: nDims, numSymbols: nSymbols)) {};
53
54private:
55 // Add a local variable (needed to flatten a mod, floordiv, ceildiv expr).
56 // The local variable added is always a floordiv of a pure add/mul affine
57 // function of other variables, coefficients of which are specified in
58 // `dividend' and with respect to the positive constant `divisor'. localExpr
59 // is the simplified tree expression (AffineExpr) corresponding to the
60 // quantifier.
61 void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
62 AffineExpr localExpr) override {
63 SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
64 // Update localVarCst.
65 localVarCst.addLocalFloorDiv(dividend, divisor);
66 }
67
68 LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
69 ArrayRef<int64_t> rhs,
70 AffineExpr localExpr) override {
71 // AffineExprFlattener does not support semi-affine expressions.
72 return failure();
73 }
74};
75
76// A SemiAffineExprFlattener is an AffineExprFlattenerWithLocalVars that adds
77// conservative bounds for semi-affine expressions (given assumptions hold). If
78// the assumptions required to add the semi-affine bounds are found not to hold
79// the final constraints set will be empty/inconsistent. If the assumptions are
80// never contradicted the final bounds still only will be correct if the
81// assumptions hold.
82struct SemiAffineExprFlattener : public AffineExprFlattener {
83 using AffineExprFlattener::AffineExprFlattener;
84
85 LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
86 ArrayRef<int64_t> rhs,
87 AffineExpr localExpr) override {
88 auto result =
89 SimpleAffineExprFlattener::addLocalIdSemiAffine(lhs, rhs, localExpr);
90 assert(succeeded(result) &&
91 "unexpected failure in SimpleAffineExprFlattener");
92 (void)result;
93
94 if (localExpr.getKind() == AffineExprKind::Mod) {
95 // Given two numbers a and b, division is defined as:
96 //
97 // a = bq + r
98 // 0 <= r < |b| (where |x| is the absolute value of x)
99 //
100 // q = a floordiv b
101 // r = a mod b
102
103 // Add a new local variable (r) to represent the mod.
104 unsigned rPos = localVarCst.appendVar(kind: VarKind::Local);
105
106 // r >= 0 (Can ALWAYS be added)
107 localVarCst.addBound(type: BoundType::LB, pos: rPos, value: 0);
108
109 // r < b (Can be added if b > 0, which we assume here)
110 ArrayRef<int64_t> b = rhs;
111 SmallVector<int64_t> bSubR(b);
112 bSubR.insert(I: bSubR.begin() + rPos, Elt: -1);
113 // Note: bSubR = b - r
114 // So this adds the bound b - r >= 1 (equivalent to r < b)
115 localVarCst.addBound(type: BoundType::LB, expr: bSubR, value: 1);
116
117 // Note: The assumption of b > 0 is based on the affine expression docs,
118 // which state "RHS of mod is always a constant or a symbolic expression
119 // with a positive value." (see AffineExprKind in AffineExpr.h). If this
120 // assumption does not hold constraints (added above) are a contradiction.
121
122 return success();
123 }
124
125 // TODO: Support other semi-affine expressions.
126 return failure();
127 }
128};
129
130} // namespace
131
132// Flattens the expressions in map. Returns failure if 'expr' was unable to be
133// flattened. For example two specific cases:
134// 1. an unhandled semi-affine expressions is found.
135// 2. has poison expression (i.e., division by zero).
136static LogicalResult
137getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
138 unsigned numSymbols,
139 std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
140 FlatLinearConstraints *localVarCst,
141 bool addConservativeSemiAffineBounds = false) {
142 if (exprs.empty()) {
143 if (localVarCst)
144 *localVarCst = FlatLinearConstraints(numDims, numSymbols);
145 return success();
146 }
147
148 auto flattenExprs = [&](AffineExprFlattener &flattener) -> LogicalResult {
149 // Use the same flattener to simplify each expression successively. This way
150 // local variables / expressions are shared.
151 for (auto expr : exprs) {
152 auto flattenResult = flattener.walkPostOrder(expr);
153 if (failed(Result: flattenResult))
154 return failure();
155 }
156
157 assert(flattener.operandExprStack.size() == exprs.size());
158 flattenedExprs->clear();
159 flattenedExprs->assign(first: flattener.operandExprStack.begin(),
160 last: flattener.operandExprStack.end());
161
162 if (localVarCst)
163 localVarCst->clearAndCopyFrom(other: flattener.localVarCst);
164
165 return success();
166 };
167
168 if (addConservativeSemiAffineBounds) {
169 SemiAffineExprFlattener flattener(numDims, numSymbols);
170 return flattenExprs(flattener);
171 }
172
173 AffineExprFlattener flattener(numDims, numSymbols);
174 return flattenExprs(flattener);
175}
176
177// Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
178// be flattened (an unhandled semi-affine was found).
179LogicalResult mlir::getFlattenedAffineExpr(
180 AffineExpr expr, unsigned numDims, unsigned numSymbols,
181 SmallVectorImpl<int64_t> *flattenedExpr, FlatLinearConstraints *localVarCst,
182 bool addConservativeSemiAffineBounds) {
183 std::vector<SmallVector<int64_t, 8>> flattenedExprs;
184 LogicalResult ret =
185 ::getFlattenedAffineExprs(exprs: {expr}, numDims, numSymbols, flattenedExprs: &flattenedExprs,
186 localVarCst, addConservativeSemiAffineBounds);
187 *flattenedExpr = flattenedExprs[0];
188 return ret;
189}
190
191/// Flattens the expressions in map. Returns failure if 'expr' was unable to be
192/// flattened (i.e., an unhandled semi-affine was found).
193LogicalResult mlir::getFlattenedAffineExprs(
194 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
195 FlatLinearConstraints *localVarCst, bool addConservativeSemiAffineBounds) {
196 if (map.getNumResults() == 0) {
197 if (localVarCst)
198 *localVarCst =
199 FlatLinearConstraints(map.getNumDims(), map.getNumSymbols());
200 return success();
201 }
202 return ::getFlattenedAffineExprs(
203 exprs: map.getResults(), numDims: map.getNumDims(), numSymbols: map.getNumSymbols(), flattenedExprs,
204 localVarCst, addConservativeSemiAffineBounds);
205}
206
207LogicalResult mlir::getFlattenedAffineExprs(
208 IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
209 FlatLinearConstraints *localVarCst) {
210 if (set.getNumConstraints() == 0) {
211 if (localVarCst)
212 *localVarCst =
213 FlatLinearConstraints(set.getNumDims(), set.getNumSymbols());
214 return success();
215 }
216 return ::getFlattenedAffineExprs(exprs: set.getConstraints(), numDims: set.getNumDims(),
217 numSymbols: set.getNumSymbols(), flattenedExprs,
218 localVarCst);
219}
220
221//===----------------------------------------------------------------------===//
222// FlatLinearConstraints
223//===----------------------------------------------------------------------===//
224
225// Similar to `composeMap` except that no Values need be associated with the
226// constraint system nor are they looked at -- the dimensions and symbols of
227// `other` are expected to correspond 1:1 to `this` system.
228LogicalResult FlatLinearConstraints::composeMatchingMap(AffineMap other) {
229 assert(other.getNumDims() == getNumDimVars() && "dim mismatch");
230 assert(other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
231
232 std::vector<SmallVector<int64_t, 8>> flatExprs;
233 if (failed(Result: flattenAlignedMapAndMergeLocals(map: other, flattenedExprs: &flatExprs)))
234 return failure();
235 assert(flatExprs.size() == other.getNumResults());
236
237 // Add dimensions corresponding to the map's results.
238 insertDimVar(/*pos=*/0, /*num=*/other.getNumResults());
239
240 // We add one equality for each result connecting the result dim of the map to
241 // the other variables.
242 // E.g.: if the expression is 16*i0 + i1, and this is the r^th
243 // iteration/result of the value map, we are adding the equality:
244 // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
245 // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
246 for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
247 const auto &flatExpr = flatExprs[r];
248 assert(flatExpr.size() >= other.getNumInputs() + 1);
249
250 SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
251 // Set the coefficient for this result to one.
252 eqToAdd[r] = 1;
253
254 // Dims and symbols.
255 for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
256 // Negate `eq[r]` since the newly added dimension will be set to this one.
257 eqToAdd[e + i] = -flatExpr[i];
258 }
259 // Local columns of `eq` are at the beginning.
260 unsigned j = getNumDimVars() + getNumSymbolVars();
261 unsigned end = flatExpr.size() - 1;
262 for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
263 eqToAdd[j] = -flatExpr[i];
264 }
265
266 // Constant term.
267 eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
268
269 // Add the equality connecting the result of the map to this constraint set.
270 addEquality(eq: eqToAdd);
271 }
272
273 return success();
274}
275
276// Determine whether the variable at 'pos' (say var_r) can be expressed as
277// modulo of another known variable (say var_n) w.r.t a constant. For example,
278// if the following constraints hold true:
279// ```
280// 0 <= var_r <= divisor - 1
281// var_n - (divisor * q_expr) = var_r
282// ```
283// where `var_n` is a known variable (called dividend), and `q_expr` is an
284// `AffineExpr` (called the quotient expression), `var_r` can be written as:
285//
286// `var_r = var_n mod divisor`.
287//
288// Additionally, in a special case of the above constaints where `q_expr` is an
289// variable itself that is not yet known (say `var_q`), it can be written as a
290// floordiv in the following way:
291//
292// `var_q = var_n floordiv divisor`.
293//
294// First 'num' dimensional variables starting at 'offset' are
295// derived/to-be-derived in terms of the remaining variables. The remaining
296// variables are assigned trivial affine expressions in `memo`. For example,
297// memo is initilized as follows for a `cst` with 5 dims, when offset=2, num=2:
298// memo ==> d0 d1 . . d2 ...
299// cst ==> c0 c1 c2 c3 c4 ...
300//
301// Returns true if the above mod or floordiv are detected, updating 'memo' with
302// these new expressions. Returns false otherwise.
303static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos,
304 unsigned offset, unsigned num, int64_t lbConst,
305 int64_t ubConst, MLIRContext *context,
306 SmallVectorImpl<AffineExpr> &memo) {
307 assert(pos < cst.getNumVars() && "invalid position");
308
309 // Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can
310 // be determined.
311 if (lbConst != 0 || ubConst < 1)
312 return false;
313 int64_t divisor = ubConst + 1;
314
315 // Check for the aforementioned conditions in each equality.
316 for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
317 curEquality < numEqualities; curEquality++) {
318 int64_t coefficientAtPos = cst.atEq64(i: curEquality, j: pos);
319 // If current equality does not involve `var_r`, continue to the next
320 // equality.
321 if (coefficientAtPos == 0)
322 continue;
323
324 // Constant term should be 0 in this equality.
325 if (cst.atEq64(i: curEquality, j: cst.getNumCols() - 1) != 0)
326 continue;
327
328 // Traverse through the equality and construct the dividend expression
329 // `dividendExpr`, to contain all the variables which are known and are
330 // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the
331 // `dividendExpr` gets simplified into a single variable `var_n` discussed
332 // above.
333 auto dividendExpr = getAffineConstantExpr(constant: 0, context);
334
335 // Track the terms that go into quotient expression, later used to detect
336 // additional floordiv.
337 unsigned quotientCount = 0;
338 int quotientPosition = -1;
339 int quotientSign = 1;
340
341 // Consider each term in the current equality.
342 unsigned curVar, e;
343 for (curVar = 0, e = cst.getNumDimAndSymbolVars(); curVar < e; ++curVar) {
344 // Ignore var_r.
345 if (curVar == pos)
346 continue;
347 int64_t coefficientOfCurVar = cst.atEq64(i: curEquality, j: curVar);
348 // Ignore vars that do not contribute to the current equality.
349 if (coefficientOfCurVar == 0)
350 continue;
351 // Check if the current var goes into the quotient expression.
352 if (coefficientOfCurVar % (divisor * coefficientAtPos) == 0) {
353 quotientCount++;
354 quotientPosition = curVar;
355 quotientSign = (coefficientOfCurVar * coefficientAtPos) > 0 ? 1 : -1;
356 continue;
357 }
358 // Variables that are part of dividendExpr should be known.
359 if (!memo[curVar])
360 break;
361 // Append the current variable to the dividend expression.
362 dividendExpr = dividendExpr + memo[curVar] * coefficientOfCurVar;
363 }
364
365 // Can't construct expression as it depends on a yet uncomputed var.
366 if (curVar < e)
367 continue;
368
369 // Express `var_r` in terms of the other vars collected so far.
370 if (coefficientAtPos > 0)
371 dividendExpr = (-dividendExpr).floorDiv(v: coefficientAtPos);
372 else
373 dividendExpr = dividendExpr.floorDiv(v: -coefficientAtPos);
374
375 // Simplify the expression.
376 dividendExpr = simplifyAffineExpr(expr: dividendExpr, numDims: cst.getNumDimVars(),
377 numSymbols: cst.getNumSymbolVars());
378 // Only if the final dividend expression is just a single var (which we call
379 // `var_n`), we can proceed.
380 // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
381 // to dims themselves.
382 auto dimExpr = dyn_cast<AffineDimExpr>(Val&: dividendExpr);
383 if (!dimExpr)
384 continue;
385
386 // Express `var_r` as `var_n % divisor` and store the expression in `memo`.
387 if (quotientCount >= 1) {
388 // Find the column corresponding to `dimExpr`. `num` columns starting at
389 // `offset` correspond to previously unknown variables. The column
390 // corresponding to the trivially known `dimExpr` can be on either side
391 // of these.
392 unsigned dimExprPos = dimExpr.getPosition();
393 unsigned dimExprCol = dimExprPos < offset ? dimExprPos : dimExprPos + num;
394 auto ub = cst.getConstantBound64(type: BoundType::UB, pos: dimExprCol);
395 // If `var_n` has an upperbound that is less than the divisor, mod can be
396 // eliminated altogether.
397 if (ub && *ub < divisor)
398 memo[pos] = dimExpr;
399 else
400 memo[pos] = dimExpr % divisor;
401 // If a unique quotient `var_q` was seen, it can be expressed as
402 // `var_n floordiv divisor`.
403 if (quotientCount == 1 && !memo[quotientPosition])
404 memo[quotientPosition] = dimExpr.floorDiv(v: divisor) * quotientSign;
405
406 return true;
407 }
408 }
409 return false;
410}
411
412/// Check if the pos^th variable can be expressed as a floordiv of an affine
413/// function of other variables (where the divisor is a positive constant)
414/// given the initial set of expressions in `exprs`. If it can be, the
415/// corresponding position in `exprs` is set as the detected affine expr. For
416/// eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. An equality can
417/// also yield a floordiv: eg. 4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
418/// <= i <= 32q + 31 => q = i floordiv 32.
419static bool detectAsFloorDiv(const FlatLinearConstraints &cst, unsigned pos,
420 MLIRContext *context,
421 SmallVectorImpl<AffineExpr> &exprs) {
422 assert(pos < cst.getNumVars() && "invalid position");
423
424 // Get upper-lower bound pair for this variable.
425 SmallVector<bool, 8> foundRepr(cst.getNumVars(), false);
426 for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i)
427 if (exprs[i])
428 foundRepr[i] = true;
429
430 SmallVector<int64_t, 8> dividend(cst.getNumCols());
431 unsigned divisor;
432 auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor);
433
434 // No upper-lower bound pair found for this var.
435 if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality)
436 return false;
437
438 // Construct the dividend expression.
439 auto dividendExpr = getAffineConstantExpr(constant: dividend.back(), context);
440 for (unsigned c = 0, f = cst.getNumVars(); c < f; c++)
441 if (dividend[c] != 0)
442 dividendExpr = dividendExpr + dividend[c] * exprs[c];
443
444 // Successfully detected the floordiv.
445 exprs[pos] = dividendExpr.floorDiv(v: divisor);
446 return true;
447}
448
449void FlatLinearConstraints::dumpRow(ArrayRef<int64_t> row,
450 bool fixedColWidth) const {
451 unsigned ncols = getNumCols();
452 bool firstNonZero = true;
453 for (unsigned j = 0; j < ncols; j++) {
454 if (j == ncols - 1) {
455 // Constant.
456 if (row[j] == 0 && !firstNonZero) {
457 if (fixedColWidth)
458 llvm::errs().indent(NumSpaces: 7);
459 } else {
460 llvm::errs() << ((row[j] >= 0) ? "+ " : "") << row[j] << ' ';
461 }
462 } else {
463 std::string var = std::string("c_") + std::to_string(val: j);
464 if (row[j] == 1)
465 llvm::errs() << "+ " << var << ' ';
466 else if (row[j] == -1)
467 llvm::errs() << "- " << var << ' ';
468 else if (row[j] >= 2)
469 llvm::errs() << "+ " << row[j] << '*' << var << ' ';
470 else if (row[j] <= -2)
471 llvm::errs() << "- " << -row[j] << '*' << var << ' ';
472 else if (fixedColWidth)
473 // Zero coeff.
474 llvm::errs().indent(NumSpaces: 7);
475 if (row[j] != 0)
476 firstNonZero = false;
477 }
478 }
479}
480
481void FlatLinearConstraints::dumpPretty() const {
482 assert(hasConsistentState());
483 llvm::errs() << "Constraints (" << getNumDimVars() << " dims, "
484 << getNumSymbolVars() << " symbols, " << getNumLocalVars()
485 << " locals), (" << getNumConstraints() << " constraints)\n";
486 auto dumpConstraint = [&](unsigned rowPos, bool isEq) {
487 // Is it the first non-zero entry?
488 SmallVector<int64_t> row =
489 isEq ? getEquality64(idx: rowPos) : getInequality64(idx: rowPos);
490 dumpRow(row);
491 llvm::errs() << (isEq ? "=" : ">=") << " 0\n";
492 };
493
494 for (unsigned i = 0, e = getNumInequalities(); i < e; i++)
495 dumpConstraint(i, /*isEq=*/false);
496 for (unsigned i = 0, e = getNumEqualities(); i < e; i++)
497 dumpConstraint(i, /*isEq=*/true);
498 llvm::errs() << '\n';
499}
500
501std::pair<AffineMap, AffineMap> FlatLinearConstraints::getLowerAndUpperBound(
502 unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
503 ArrayRef<AffineExpr> localExprs, MLIRContext *context,
504 bool closedUB) const {
505 assert(pos + offset < getNumDimVars() && "invalid dim start pos");
506 assert(symStartPos >= (pos + offset) && "invalid sym start pos");
507 assert(getNumLocalVars() == localExprs.size() &&
508 "incorrect local exprs count");
509
510 SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
511 getLowerAndUpperBoundIndices(pos: pos + offset, lbIndices: &lbIndices, ubIndices: &ubIndices, eqIndices: &eqIndices,
512 offset, num);
513
514 /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
515 auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
516 b.clear();
517 for (unsigned i = 0, e = a.size(); i < e; ++i) {
518 if (i < offset || i >= offset + num)
519 b.push_back(Elt: a[i]);
520 }
521 };
522
523 SmallVector<int64_t, 8> lb, ub;
524 SmallVector<AffineExpr, 4> lbExprs;
525 unsigned dimCount = symStartPos - num;
526 unsigned symCount = getNumDimAndSymbolVars() - symStartPos;
527 lbExprs.reserve(N: lbIndices.size() + eqIndices.size());
528 // Lower bound expressions.
529 for (auto idx : lbIndices) {
530 auto ineq = getInequality64(idx);
531 // Extract the lower bound (in terms of other coeff's + const), i.e., if
532 // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
533 // - 1.
534 addCoeffs(ineq, lb);
535 std::transform(first: lb.begin(), last: lb.end(), result: lb.begin(), unary_op: std::negate<int64_t>());
536 auto expr =
537 getAffineExprFromFlatForm(flatExprs: lb, numDims: dimCount, numSymbols: symCount, localExprs, context);
538 // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
539 int64_t divisor = std::abs(i: ineq[pos + offset]);
540 expr = (expr + divisor - 1).floorDiv(v: divisor);
541 lbExprs.push_back(Elt: expr);
542 }
543
544 SmallVector<AffineExpr, 4> ubExprs;
545 ubExprs.reserve(N: ubIndices.size() + eqIndices.size());
546 // Upper bound expressions.
547 for (auto idx : ubIndices) {
548 auto ineq = getInequality64(idx);
549 // Extract the upper bound (in terms of other coeff's + const).
550 addCoeffs(ineq, ub);
551 auto expr =
552 getAffineExprFromFlatForm(flatExprs: ub, numDims: dimCount, numSymbols: symCount, localExprs, context);
553 expr = expr.floorDiv(v: std::abs(i: ineq[pos + offset]));
554 int64_t ubAdjustment = closedUB ? 0 : 1;
555 ubExprs.push_back(Elt: expr + ubAdjustment);
556 }
557
558 // Equalities. It's both a lower and a upper bound.
559 SmallVector<int64_t, 4> b;
560 for (auto idx : eqIndices) {
561 auto eq = getEquality64(idx);
562 addCoeffs(eq, b);
563 if (eq[pos + offset] > 0)
564 std::transform(first: b.begin(), last: b.end(), result: b.begin(), unary_op: std::negate<int64_t>());
565
566 // Extract the upper bound (in terms of other coeff's + const).
567 auto expr =
568 getAffineExprFromFlatForm(flatExprs: b, numDims: dimCount, numSymbols: symCount, localExprs, context);
569 expr = expr.floorDiv(v: std::abs(i: eq[pos + offset]));
570 // Upper bound is exclusive.
571 ubExprs.push_back(Elt: expr + 1);
572 // Lower bound.
573 expr =
574 getAffineExprFromFlatForm(flatExprs: b, numDims: dimCount, numSymbols: symCount, localExprs, context);
575 expr = expr.ceilDiv(v: std::abs(i: eq[pos + offset]));
576 lbExprs.push_back(Elt: expr);
577 }
578
579 auto lbMap = AffineMap::get(dimCount, symbolCount: symCount, results: lbExprs, context);
580 auto ubMap = AffineMap::get(dimCount, symbolCount: symCount, results: ubExprs, context);
581
582 return {lbMap, ubMap};
583}
584
585/// Express the pos^th identifier of `cst` as an affine expression in
586/// terms of other identifiers, if they are available in `exprs`, using the
587/// equality at position `idx` in `cs`t. Populates `exprs` with such an
588/// expression if possible, and return true. Returns false otherwise.
589static bool detectAsExpr(const FlatLinearConstraints &cst, unsigned pos,
590 unsigned idx, MLIRContext *context,
591 SmallVectorImpl<AffineExpr> &exprs) {
592 // Initialize with a `0` expression.
593 auto expr = getAffineConstantExpr(constant: 0, context);
594
595 // Traverse `idx`th equality and construct the possible affine expression in
596 // terms of known identifiers.
597 unsigned j, e;
598 for (j = 0, e = cst.getNumVars(); j < e; ++j) {
599 if (j == pos)
600 continue;
601 int64_t c = cst.atEq64(i: idx, j);
602 if (c == 0)
603 continue;
604 // If any of the involved IDs hasn't been found yet, we can't proceed.
605 if (!exprs[j])
606 break;
607 expr = expr + exprs[j] * c;
608 }
609 if (j < e)
610 // Can't construct expression as it depends on a yet uncomputed
611 // identifier.
612 return false;
613
614 // Add constant term to AffineExpr.
615 expr = expr + cst.atEq64(i: idx, j: cst.getNumVars());
616 int64_t vPos = cst.atEq64(i: idx, j: pos);
617 assert(vPos != 0 && "expected non-zero here");
618 if (vPos > 0)
619 expr = (-expr).floorDiv(v: vPos);
620 else
621 // vPos < 0.
622 expr = expr.floorDiv(v: -vPos);
623 // Successfully constructed expression.
624 exprs[pos] = expr;
625 return true;
626}
627
628/// Compute a representation of `num` identifiers starting at `offset` in `cst`
629/// as affine expressions involving other known identifiers. Each identifier's
630/// expression (in terms of known identifiers) is populated into `memo`.
631static void computeUnknownVars(const FlatLinearConstraints &cst,
632 MLIRContext *context, unsigned offset,
633 unsigned num,
634 SmallVectorImpl<AffineExpr> &memo) {
635 // Initialize dimensional and symbolic variables.
636 for (unsigned i = 0, e = cst.getNumDimVars(); i < e; i++) {
637 if (i < offset)
638 memo[i] = getAffineDimExpr(position: i, context);
639 else if (i >= offset + num)
640 memo[i] = getAffineDimExpr(position: i - num, context);
641 }
642 for (unsigned i = cst.getNumDimVars(), e = cst.getNumDimAndSymbolVars();
643 i < e; i++)
644 memo[i] = getAffineSymbolExpr(position: i - cst.getNumDimVars(), context);
645
646 bool changed;
647 do {
648 changed = false;
649 // Identify yet unknown variables as constants or mod's / floordiv's of
650 // other variables if possible.
651 for (unsigned pos = 0, f = cst.getNumVars(); pos < f; pos++) {
652 if (memo[pos])
653 continue;
654
655 auto lbConst = cst.getConstantBound64(type: BoundType::LB, pos);
656 auto ubConst = cst.getConstantBound64(type: BoundType::UB, pos);
657 if (lbConst.has_value() && ubConst.has_value()) {
658 // Detect equality to a constant.
659 if (*lbConst == *ubConst) {
660 memo[pos] = getAffineConstantExpr(constant: *lbConst, context);
661 changed = true;
662 continue;
663 }
664
665 // Detect a variable as modulo of another variable w.r.t a
666 // constant.
667 if (detectAsMod(cst, pos, offset, num, lbConst: *lbConst, ubConst: *ubConst, context,
668 memo)) {
669 changed = true;
670 continue;
671 }
672 }
673
674 // Detect a variable as a floordiv of an affine function of other
675 // variables (divisor is a positive constant).
676 if (detectAsFloorDiv(cst, pos, context, exprs&: memo)) {
677 changed = true;
678 continue;
679 }
680
681 // Detect a variable as an expression of other variables.
682 std::optional<unsigned> idx;
683 if (!(idx = cst.findConstraintWithNonZeroAt(colIdx: pos, /*isEq=*/true)))
684 continue;
685
686 if (detectAsExpr(cst, pos, idx: *idx, context, exprs&: memo)) {
687 changed = true;
688 continue;
689 }
690 }
691 // This loop is guaranteed to reach a fixed point - since once an
692 // variable's explicit form is computed (in memo[pos]), it's not updated
693 // again.
694 } while (changed);
695}
696
697/// Computes the lower and upper bounds of the first 'num' dimensional
698/// variables (starting at 'offset') as affine maps of the remaining
699/// variables (dimensional and symbolic variables). Local variables are
700/// themselves explicitly computed as affine functions of other variables in
701/// this process if needed.
702void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num,
703 MLIRContext *context,
704 SmallVectorImpl<AffineMap> *lbMaps,
705 SmallVectorImpl<AffineMap> *ubMaps,
706 bool closedUB) {
707 assert(offset + num <= getNumDimVars() && "invalid range");
708
709 // Basic simplification.
710 normalizeConstraintsByGCD();
711
712 LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for variables at positions ["
713 << offset << ", " << offset + num << ")\n");
714 LLVM_DEBUG(dumpPretty());
715
716 // Record computed/detected variables.
717 SmallVector<AffineExpr, 8> memo(getNumVars());
718 computeUnknownVars(cst: *this, context, offset, num, memo);
719
720 int64_t ubAdjustment = closedUB ? 0 : 1;
721
722 // Set the lower and upper bound maps for all the variables that were
723 // computed as affine expressions of the rest as the "detected expr" and
724 // "detected expr + 1" respectively; set the undetected ones to null.
725 std::optional<FlatLinearConstraints> tmpClone;
726 for (unsigned pos = 0; pos < num; pos++) {
727 unsigned numMapDims = getNumDimVars() - num;
728 unsigned numMapSymbols = getNumSymbolVars();
729 AffineExpr expr = memo[pos + offset];
730 if (expr)
731 expr = simplifyAffineExpr(expr, numDims: numMapDims, numSymbols: numMapSymbols);
732
733 AffineMap &lbMap = (*lbMaps)[pos];
734 AffineMap &ubMap = (*ubMaps)[pos];
735
736 if (expr) {
737 lbMap = AffineMap::get(dimCount: numMapDims, symbolCount: numMapSymbols, result: expr);
738 ubMap = AffineMap::get(dimCount: numMapDims, symbolCount: numMapSymbols, result: expr + ubAdjustment);
739 } else {
740 // TODO: Whenever there are local variables in the dependence
741 // constraints, we'll conservatively over-approximate, since we don't
742 // always explicitly compute them above (in the while loop).
743 if (getNumLocalVars() == 0) {
744 // Work on a copy so that we don't update this constraint system.
745 if (!tmpClone) {
746 tmpClone.emplace(args: FlatLinearConstraints(*this));
747 // Removing redundant inequalities is necessary so that we don't get
748 // redundant loop bounds.
749 tmpClone->removeRedundantInequalities();
750 }
751 std::tie(args&: lbMap, args&: ubMap) = tmpClone->getLowerAndUpperBound(
752 pos, offset, num, symStartPos: getNumDimVars(), /*localExprs=*/{}, context,
753 closedUB);
754 }
755
756 // If the above fails, we'll just use the constant lower bound and the
757 // constant upper bound (if they exist) as the slice bounds.
758 // TODO: being conservative for the moment in cases that
759 // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
760 // fixed (b/126426796).
761 if (!lbMap || lbMap.getNumResults() != 1) {
762 LLVM_DEBUG(llvm::dbgs()
763 << "WARNING: Potentially over-approximating slice lb\n");
764 auto lbConst = getConstantBound64(type: BoundType::LB, pos: pos + offset);
765 if (lbConst.has_value()) {
766 lbMap = AffineMap::get(dimCount: numMapDims, symbolCount: numMapSymbols,
767 result: getAffineConstantExpr(constant: *lbConst, context));
768 }
769 }
770 if (!ubMap || ubMap.getNumResults() != 1) {
771 LLVM_DEBUG(llvm::dbgs()
772 << "WARNING: Potentially over-approximating slice ub\n");
773 auto ubConst = getConstantBound64(type: BoundType::UB, pos: pos + offset);
774 if (ubConst.has_value()) {
775 ubMap = AffineMap::get(
776 dimCount: numMapDims, symbolCount: numMapSymbols,
777 result: getAffineConstantExpr(constant: *ubConst + ubAdjustment, context));
778 }
779 }
780 }
781
782 LLVM_DEBUG(llvm::dbgs() << "Slice bounds:\n");
783 LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos + offset)
784 << ", expr: " << lbMap << '\n');
785 LLVM_DEBUG(llvm::dbgs() << "ub map for pos = " << Twine(pos + offset)
786 << ", expr: " << ubMap << '\n');
787 }
788}
789
790LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals(
791 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
792 bool addConservativeSemiAffineBounds) {
793 FlatLinearConstraints localCst;
794 if (failed(Result: getFlattenedAffineExprs(map, flattenedExprs, localVarCst: &localCst,
795 addConservativeSemiAffineBounds))) {
796 LLVM_DEBUG(llvm::dbgs()
797 << "composition unimplemented for semi-affine maps\n");
798 return failure();
799 }
800
801 // Add localCst information.
802 if (localCst.getNumLocalVars() > 0) {
803 unsigned numLocalVars = getNumLocalVars();
804 // Insert local dims of localCst at the beginning.
805 insertLocalVar(/*pos=*/0, /*num=*/localCst.getNumLocalVars());
806 // Insert local dims of `this` at the end of localCst.
807 localCst.appendLocalVar(/*num=*/numLocalVars);
808 // Dimensions of localCst and this constraint set match. Append localCst to
809 // this constraint set.
810 append(other: localCst);
811 }
812
813 return success();
814}
815
816LogicalResult FlatLinearConstraints::addBound(
817 BoundType type, unsigned pos, AffineMap boundMap, bool isClosedBound,
818 AddConservativeSemiAffineBounds addSemiAffineBounds) {
819 assert(boundMap.getNumDims() == getNumDimVars() && "dim mismatch");
820 assert(boundMap.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
821 assert(pos < getNumDimAndSymbolVars() && "invalid position");
822 assert((type != BoundType::EQ || isClosedBound) &&
823 "EQ bound must be closed.");
824
825 // Equality follows the logic of lower bound except that we add an equality
826 // instead of an inequality.
827 assert((type != BoundType::EQ || boundMap.getNumResults() == 1) &&
828 "single result expected");
829 bool lower = type == BoundType::LB || type == BoundType::EQ;
830
831 std::vector<SmallVector<int64_t, 8>> flatExprs;
832 if (failed(Result: flattenAlignedMapAndMergeLocals(
833 map: boundMap, flattenedExprs: &flatExprs,
834 addConservativeSemiAffineBounds: addSemiAffineBounds == AddConservativeSemiAffineBounds::Yes)))
835 return failure();
836 assert(flatExprs.size() == boundMap.getNumResults());
837
838 // Add one (in)equality for each result.
839 for (const auto &flatExpr : flatExprs) {
840 SmallVector<int64_t> ineq(getNumCols(), 0);
841 // Dims and symbols.
842 for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
843 ineq[j] = lower ? -flatExpr[j] : flatExpr[j];
844 }
845 // Invalid bound: pos appears in `boundMap`.
846 // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or
847 // its callers to prevent invalid bounds from being added.
848 if (ineq[pos] != 0)
849 continue;
850 ineq[pos] = lower ? 1 : -1;
851 // Local columns of `ineq` are at the beginning.
852 unsigned j = getNumDimVars() + getNumSymbolVars();
853 unsigned end = flatExpr.size() - 1;
854 for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {
855 ineq[j] = lower ? -flatExpr[i] : flatExpr[i];
856 }
857 // Make the bound closed in if flatExpr is open. The inequality is always
858 // created in the upper bound form, so the adjustment is -1.
859 int64_t boundAdjustment = (isClosedBound || type == BoundType::EQ) ? 0 : -1;
860 // Constant term.
861 ineq[getNumCols() - 1] = (lower ? -flatExpr[flatExpr.size() - 1]
862 : flatExpr[flatExpr.size() - 1]) +
863 boundAdjustment;
864 type == BoundType::EQ ? addEquality(eq: ineq) : addInequality(inEq: ineq);
865 }
866
867 return success();
868}
869
870LogicalResult FlatLinearConstraints::addBound(
871 BoundType type, unsigned pos, AffineMap boundMap,
872 AddConservativeSemiAffineBounds addSemiAffineBounds) {
873 return addBound(type, pos, boundMap,
874 /*isClosedBound=*/type != BoundType::UB, addSemiAffineBounds);
875}
876
877/// Compute an explicit representation for local vars. For all systems coming
878/// from MLIR integer sets, maps, or expressions where local vars were
879/// introduced to model floordivs and mods, this always succeeds.
880LogicalResult
881FlatLinearConstraints::computeLocalVars(SmallVectorImpl<AffineExpr> &memo,
882 MLIRContext *context) const {
883 unsigned numDims = getNumDimVars();
884 unsigned numSyms = getNumSymbolVars();
885
886 // Initialize dimensional and symbolic variables.
887 for (unsigned i = 0; i < numDims; i++)
888 memo[i] = getAffineDimExpr(position: i, context);
889 for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
890 memo[i] = getAffineSymbolExpr(position: i - numDims, context);
891
892 bool changed;
893 do {
894 // Each time `changed` is true at the end of this iteration, one or more
895 // local vars would have been detected as floordivs and set in memo; so the
896 // number of null entries in memo[...] strictly reduces; so this converges.
897 changed = false;
898 for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i)
899 if (!memo[numDims + numSyms + i] &&
900 detectAsFloorDiv(cst: *this, /*pos=*/numDims + numSyms + i, context, exprs&: memo))
901 changed = true;
902 } while (changed);
903
904 ArrayRef<AffineExpr> localExprs =
905 ArrayRef<AffineExpr>(memo).take_back(N: getNumLocalVars());
906 return success(
907 IsSuccess: llvm::all_of(Range&: localExprs, P: [](AffineExpr expr) { return expr; }));
908}
909
910/// Given an equality or inequality (`isEquality` used to disambiguate) of `cst`
911/// at `idx`, traverse and sum up `AffineExpr`s of all known ids other than the
912/// `pos`th. Known `AffineExpr`s are given in `exprs` (unknowns are null). If
913/// the equality/inequality contains any unknown id, return None. Otherwise
914/// return sum as `AffineExpr`.
915static std::optional<AffineExpr> getAsExpr(const FlatLinearConstraints &cst,
916 unsigned pos, MLIRContext *context,
917 ArrayRef<AffineExpr> exprs,
918 unsigned idx, bool isEquality) {
919 // Initialize with a `0` expression.
920 auto expr = getAffineConstantExpr(constant: 0, context);
921
922 SmallVector<int64_t, 8> row =
923 isEquality ? cst.getEquality64(idx) : cst.getInequality64(idx);
924
925 // Traverse `idx`th equality and construct the possible affine expression in
926 // terms of known identifiers.
927 unsigned j, e;
928 for (j = 0, e = cst.getNumVars(); j < e; ++j) {
929 if (j == pos)
930 continue;
931 int64_t c = row[j];
932 if (c == 0)
933 continue;
934 // If any of the involved IDs hasn't been found yet, we can't proceed.
935 if (!exprs[j])
936 break;
937 expr = expr + exprs[j] * c;
938 }
939 if (j < e)
940 // Can't construct expression as it depends on a yet uncomputed
941 // identifier.
942 return std::nullopt;
943
944 // Add constant term to AffineExpr.
945 expr = expr + row[cst.getNumVars()];
946 return expr;
947}
948
949std::optional<int64_t> FlatLinearConstraints::getConstantBoundOnDimSize(
950 MLIRContext *context, unsigned pos, AffineMap *lb, AffineMap *ub,
951 unsigned *minLbPos, unsigned *minUbPos) const {
952
953 assert(pos < getNumDimVars() && "Invalid identifier position");
954
955 auto freeOfUnknownLocalVars = [&](ArrayRef<int64_t> cst,
956 ArrayRef<AffineExpr> whiteListCols) {
957 for (int i = getNumDimAndSymbolVars(), e = cst.size() - 1; i < e; ++i) {
958 if (whiteListCols[i] && whiteListCols[i].isSymbolicOrConstant())
959 continue;
960 if (cst[i] != 0)
961 return false;
962 }
963 return true;
964 };
965
966 // Detect the necesary local variables first.
967 SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr());
968 (void)computeLocalVars(memo, context);
969
970 // Find an equality for 'pos'^th identifier that equates it to some function
971 // of the symbolic identifiers (+ constant).
972 int eqPos = findEqualityToConstant(pos, /*symbolic=*/true);
973 // If the equality involves a local var that can not be expressed as a
974 // symbolic or constant affine expression, we bail out.
975 if (eqPos != -1 && freeOfUnknownLocalVars(getEquality64(idx: eqPos), memo)) {
976 // This identifier can only take a single value.
977 if (lb && detectAsExpr(cst: *this, pos, idx: eqPos, context, exprs&: memo)) {
978 AffineExpr equalityExpr =
979 simplifyAffineExpr(expr: memo[pos], numDims: 0, numSymbols: getNumSymbolVars());
980 *lb = AffineMap::get(/*dimCount=*/0, symbolCount: getNumSymbolVars(), result: equalityExpr);
981 if (ub)
982 *ub = *lb;
983 }
984 if (minLbPos)
985 *minLbPos = eqPos;
986 if (minUbPos)
987 *minUbPos = eqPos;
988 return 1;
989 }
990
991 // Positions of constraints that are lower/upper bounds on the variable.
992 SmallVector<unsigned, 4> lbIndices, ubIndices;
993
994 // Note inequalities that give lower and upper bounds.
995 getLowerAndUpperBoundIndices(pos, lbIndices: &lbIndices, ubIndices: &ubIndices,
996 /*eqIndices=*/nullptr, /*offset=*/0,
997 /*num=*/getNumDimVars());
998
999 std::optional<int64_t> minDiff = std::nullopt;
1000 unsigned minLbPosition = 0, minUbPosition = 0;
1001 AffineExpr minLbExpr, minUbExpr;
1002
1003 // Traverse each lower bound and upper bound pair, to compute the difference
1004 // between them.
1005 for (unsigned ubPos : ubIndices) {
1006 // Construct sum of all ids other than `pos`th in the given upper bound row.
1007 std::optional<AffineExpr> maybeUbExpr =
1008 getAsExpr(cst: *this, pos, context, exprs: memo, idx: ubPos, /*isEquality=*/false);
1009 if (!maybeUbExpr.has_value() || !(*maybeUbExpr).isSymbolicOrConstant())
1010 continue;
1011
1012 // Canonical form of an inequality that constrains the upper bound on
1013 // an id `x_i` is of the form:
1014 // `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` <= -1.
1015 // Therefore the upper bound on `x_i` will be
1016 // `(
1017 // sum(c_j*x_j) where j != i
1018 // +
1019 // c_0
1020 // )
1021 // /
1022 // -(c_i)`. Divison here is a floorDiv.
1023 AffineExpr ubExpr = maybeUbExpr->floorDiv(v: -atIneq64(i: ubPos, j: pos));
1024 assert(-atIneq64(ubPos, pos) > 0 && "invalid upper bound index");
1025
1026 // Go over each lower bound.
1027 for (unsigned lbPos : lbIndices) {
1028 // Construct sum of all ids other than `pos`th in the given lower bound
1029 // row.
1030 std::optional<AffineExpr> maybeLbExpr =
1031 getAsExpr(cst: *this, pos, context, exprs: memo, idx: lbPos, /*isEquality=*/false);
1032 if (!maybeLbExpr.has_value() || !(*maybeLbExpr).isSymbolicOrConstant())
1033 continue;
1034
1035 // Canonical form of an inequality that is constraining the lower bound
1036 // on an id `x_i is of the form:
1037 // `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` >= 1.
1038 // Therefore upperBound on `x_i` will be
1039 // `-(
1040 // sum(c_j*x_j) where j != i
1041 // +
1042 // c_0
1043 // )
1044 // /
1045 // c_i`. Divison here is a ceilDiv.
1046 int64_t divisor = atIneq64(i: lbPos, j: pos);
1047 // We convert the `ceilDiv` for floordiv with the formula:
1048 // `expr ceildiv divisor is (expr + divisor - 1) floordiv divisor`,
1049 // since uniformly keeping divisons as `floorDiv` helps their
1050 // simplification.
1051 AffineExpr lbExpr = (-(*maybeLbExpr) + divisor - 1).floorDiv(v: divisor);
1052 assert(atIneq64(lbPos, pos) > 0 && "invalid lower bound index");
1053
1054 AffineExpr difference =
1055 simplifyAffineExpr(expr: ubExpr - lbExpr + 1, numDims: 0, numSymbols: getNumSymbolVars());
1056 // If the difference is not constant, ignore the lower bound - upper bound
1057 // pair.
1058 auto constantDiff = dyn_cast<AffineConstantExpr>(Val&: difference);
1059 if (!constantDiff)
1060 continue;
1061
1062 int64_t diffValue = constantDiff.getValue();
1063 // This bound is non-negative by definition.
1064 diffValue = std::max<int64_t>(a: diffValue, b: 0);
1065 if (!minDiff || diffValue < *minDiff) {
1066 minDiff = diffValue;
1067 minLbPosition = lbPos;
1068 minUbPosition = ubPos;
1069 minLbExpr = lbExpr;
1070 minUbExpr = ubExpr;
1071 }
1072 }
1073 }
1074
1075 // Populate outputs where available and needed.
1076 if (lb && minDiff) {
1077 *lb = AffineMap::get(/*dimCount=*/0, symbolCount: getNumSymbolVars(), result: minLbExpr);
1078 }
1079 if (ub)
1080 *ub = AffineMap::get(/*dimCount=*/0, symbolCount: getNumSymbolVars(), result: minUbExpr);
1081 if (minLbPos)
1082 *minLbPos = minLbPosition;
1083 if (minUbPos)
1084 *minUbPos = minUbPosition;
1085
1086 return minDiff;
1087}
1088
1089IntegerSet FlatLinearConstraints::getAsIntegerSet(MLIRContext *context) const {
1090 if (getNumConstraints() == 0)
1091 // Return universal set (always true): 0 == 0.
1092 return IntegerSet::get(dimCount: getNumDimVars(), symbolCount: getNumSymbolVars(),
1093 constraints: getAffineConstantExpr(/*constant=*/0, context),
1094 /*eqFlags=*/true);
1095
1096 // Construct local references.
1097 SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr());
1098
1099 if (failed(Result: computeLocalVars(memo, context))) {
1100 // Check if the local variables without an explicit representation have
1101 // zero coefficients everywhere.
1102 SmallVector<unsigned> noLocalRepVars;
1103 unsigned numDimsSymbols = getNumDimAndSymbolVars();
1104 for (unsigned i = numDimsSymbols, e = getNumVars(); i < e; ++i) {
1105 if (!memo[i] && !isColZero(/*pos=*/i))
1106 noLocalRepVars.push_back(Elt: i - numDimsSymbols);
1107 }
1108 if (!noLocalRepVars.empty()) {
1109 LLVM_DEBUG({
1110 llvm::dbgs() << "local variables at position(s) "
1111 << llvm::interleaved(noLocalRepVars)
1112 << " do not have an explicit representation in:\n";
1113 this->dump();
1114 });
1115 return IntegerSet();
1116 }
1117 }
1118
1119 ArrayRef<AffineExpr> localExprs =
1120 ArrayRef<AffineExpr>(memo).take_back(N: getNumLocalVars());
1121
1122 // Construct the IntegerSet from the equalities/inequalities.
1123 unsigned numDims = getNumDimVars();
1124 unsigned numSyms = getNumSymbolVars();
1125
1126 SmallVector<bool, 16> eqFlags(getNumConstraints());
1127 std::fill(first: eqFlags.begin(), last: eqFlags.begin() + getNumEqualities(), value: true);
1128 std::fill(first: eqFlags.begin() + getNumEqualities(), last: eqFlags.end(), value: false);
1129
1130 SmallVector<AffineExpr, 8> exprs;
1131 exprs.reserve(N: getNumConstraints());
1132
1133 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
1134 exprs.push_back(Elt: getAffineExprFromFlatForm(flatExprs: getEquality64(idx: i), numDims,
1135 numSymbols: numSyms, localExprs, context));
1136 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
1137 exprs.push_back(Elt: getAffineExprFromFlatForm(flatExprs: getInequality64(idx: i), numDims,
1138 numSymbols: numSyms, localExprs, context));
1139 return IntegerSet::get(dimCount: numDims, symbolCount: numSyms, constraints: exprs, eqFlags);
1140}
1141
1142//===----------------------------------------------------------------------===//
1143// FlatLinearValueConstraints
1144//===----------------------------------------------------------------------===//
1145
1146// Construct from an IntegerSet.
1147FlatLinearValueConstraints::FlatLinearValueConstraints(IntegerSet set,
1148 ValueRange operands)
1149 : FlatLinearConstraints(set.getNumInequalities(), set.getNumEqualities(),
1150 set.getNumDims() + set.getNumSymbols() + 1,
1151 set.getNumDims(), set.getNumSymbols(),
1152 /*numLocals=*/0) {
1153 assert((operands.empty() || set.getNumInputs() == operands.size()) &&
1154 "operand count mismatch");
1155 // Set the values for the non-local variables.
1156 for (unsigned i = 0, e = operands.size(); i < e; ++i)
1157 setValue(pos: i, val: operands[i]);
1158
1159 // Flatten expressions and add them to the constraint system.
1160 std::vector<SmallVector<int64_t, 8>> flatExprs;
1161 FlatLinearConstraints localVarCst;
1162 if (failed(Result: getFlattenedAffineExprs(set, flattenedExprs: &flatExprs, localVarCst: &localVarCst))) {
1163 assert(false && "flattening unimplemented for semi-affine integer sets");
1164 return;
1165 }
1166 assert(flatExprs.size() == set.getNumConstraints());
1167 insertVar(kind: VarKind::Local, pos: getNumVarKind(kind: VarKind::Local),
1168 /*num=*/localVarCst.getNumLocalVars());
1169
1170 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
1171 const auto &flatExpr = flatExprs[i];
1172 assert(flatExpr.size() == getNumCols());
1173 if (set.getEqFlags()[i]) {
1174 addEquality(eq: flatExpr);
1175 } else {
1176 addInequality(inEq: flatExpr);
1177 }
1178 }
1179 // Add the other constraints involving local vars from flattening.
1180 append(other: localVarCst);
1181}
1182
1183unsigned FlatLinearValueConstraints::appendDimVar(ValueRange vals) {
1184 unsigned pos = getNumDimVars();
1185 return insertVar(kind: VarKind::SetDim, pos, vals);
1186}
1187
1188unsigned FlatLinearValueConstraints::appendSymbolVar(ValueRange vals) {
1189 unsigned pos = getNumSymbolVars();
1190 return insertVar(kind: VarKind::Symbol, pos, vals);
1191}
1192
1193unsigned FlatLinearValueConstraints::insertDimVar(unsigned pos,
1194 ValueRange vals) {
1195 return insertVar(kind: VarKind::SetDim, pos, vals);
1196}
1197
1198unsigned FlatLinearValueConstraints::insertSymbolVar(unsigned pos,
1199 ValueRange vals) {
1200 return insertVar(kind: VarKind::Symbol, pos, vals);
1201}
1202
1203unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
1204 unsigned num) {
1205 unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
1206
1207 return absolutePos;
1208}
1209
1210unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
1211 ValueRange vals) {
1212 assert(!vals.empty() && "expected ValueRange with Values.");
1213 assert(kind != VarKind::Local &&
1214 "values cannot be attached to local variables.");
1215 unsigned num = vals.size();
1216 unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
1217
1218 // If a Value is provided, insert it; otherwise use std::nullopt.
1219 for (unsigned i = 0, e = vals.size(); i < e; ++i)
1220 if (vals[i])
1221 setValue(pos: absolutePos + i, val: vals[i]);
1222
1223 return absolutePos;
1224}
1225
1226/// Checks if two constraint systems are in the same space, i.e., if they are
1227/// associated with the same set of variables, appearing in the same order.
1228static bool areVarsAligned(const FlatLinearValueConstraints &a,
1229 const FlatLinearValueConstraints &b) {
1230 if (a.getNumDomainVars() != b.getNumDomainVars() ||
1231 a.getNumRangeVars() != b.getNumRangeVars() ||
1232 a.getNumSymbolVars() != b.getNumSymbolVars())
1233 return false;
1234 SmallVector<std::optional<Value>> aMaybeValues = a.getMaybeValues(),
1235 bMaybeValues = b.getMaybeValues();
1236 return std::equal(first1: aMaybeValues.begin(), last1: aMaybeValues.end(),
1237 first2: bMaybeValues.begin(), last2: bMaybeValues.end());
1238}
1239
1240/// Calls areVarsAligned to check if two constraint systems have the same set
1241/// of variables in the same order.
1242bool FlatLinearValueConstraints::areVarsAlignedWithOther(
1243 const FlatLinearConstraints &other) {
1244 return areVarsAligned(a: *this, b: other);
1245}
1246
1247/// Checks if the SSA values associated with `cst`'s variables in range
1248/// [start, end) are unique.
1249static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
1250 const FlatLinearValueConstraints &cst, unsigned start, unsigned end) {
1251
1252 assert(start <= cst.getNumDimAndSymbolVars() &&
1253 "Start position out of bounds");
1254 assert(end <= cst.getNumDimAndSymbolVars() && "End position out of bounds");
1255
1256 if (start >= end)
1257 return true;
1258
1259 SmallPtrSet<Value, 8> uniqueVars;
1260 SmallVector<std::optional<Value>, 8> maybeValuesAll = cst.getMaybeValues();
1261 ArrayRef<std::optional<Value>> maybeValues = {maybeValuesAll.data() + start,
1262 maybeValuesAll.data() + end};
1263
1264 for (std::optional<Value> val : maybeValues)
1265 if (val && !uniqueVars.insert(Ptr: *val).second)
1266 return false;
1267
1268 return true;
1269}
1270
1271/// Checks if the SSA values associated with `cst`'s variables are unique.
1272static bool LLVM_ATTRIBUTE_UNUSED
1273areVarsUnique(const FlatLinearValueConstraints &cst) {
1274 return areVarsUnique(cst, start: 0, end: cst.getNumDimAndSymbolVars());
1275}
1276
1277/// Checks if the SSA values associated with `cst`'s variables of kind `kind`
1278/// are unique.
1279static bool LLVM_ATTRIBUTE_UNUSED
1280areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) {
1281
1282 if (kind == VarKind::SetDim)
1283 return areVarsUnique(cst, start: 0, end: cst.getNumDimVars());
1284 if (kind == VarKind::Symbol)
1285 return areVarsUnique(cst, start: cst.getNumDimVars(),
1286 end: cst.getNumDimAndSymbolVars());
1287 llvm_unreachable("Unexpected VarKind");
1288}
1289
1290/// Merge and align the variables of A and B starting at 'offset', so that
1291/// both constraint systems get the union of the contained variables that is
1292/// dimension-wise and symbol-wise unique; both constraint systems are updated
1293/// so that they have the union of all variables, with A's original
1294/// variables appearing first followed by any of B's variables that didn't
1295/// appear in A. Local variables in B that have the same division
1296/// representation as local variables in A are merged into one. We allow A
1297/// and B to have non-unique values for their variables; in such cases, they are
1298/// still aligned with the variables appearing first aligned with those
1299/// appearing first in the other system from left to right.
1300// E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M])
1301// Output: both A, B have (%i, %j, %k) [%M, %N, %P]
1302static void mergeAndAlignVars(unsigned offset, FlatLinearValueConstraints *a,
1303 FlatLinearValueConstraints *b) {
1304 assert(offset <= a->getNumDimVars() && offset <= b->getNumDimVars());
1305
1306 assert(llvm::all_of(
1307 llvm::drop_begin(a->getMaybeValues(), offset),
1308 [](const std::optional<Value> &var) { return var.has_value(); }));
1309
1310 assert(llvm::all_of(
1311 llvm::drop_begin(b->getMaybeValues(), offset),
1312 [](const std::optional<Value> &var) { return var.has_value(); }));
1313
1314 SmallVector<Value, 4> aDimValues;
1315 a->getValues(start: offset, end: a->getNumDimVars(), values: &aDimValues);
1316
1317 {
1318 // Merge dims from A into B.
1319 unsigned d = offset;
1320 for (Value aDimValue : aDimValues) {
1321 unsigned loc;
1322 // Find from the position `d` since we'd like to also consider the
1323 // possibility of multiple variables with the same `Value`. We align with
1324 // the next appearing one.
1325 if (b->findVar(val: aDimValue, pos: &loc, offset: d)) {
1326 assert(loc >= offset && "A's dim appears in B's aligned range");
1327 assert(loc < b->getNumDimVars() &&
1328 "A's dim appears in B's non-dim position");
1329 b->swapVar(posA: d, posB: loc);
1330 } else {
1331 b->insertDimVar(pos: d, vals: aDimValue);
1332 }
1333 d++;
1334 }
1335 // Dimensions that are in B, but not in A, are added at the end.
1336 for (unsigned t = a->getNumDimVars(), e = b->getNumDimVars(); t < e; t++) {
1337 a->appendDimVar(vals: b->getValue(pos: t));
1338 }
1339 assert(a->getNumDimVars() == b->getNumDimVars() &&
1340 "expected same number of dims");
1341 }
1342
1343 // Merge and align symbols of A and B
1344 a->mergeSymbolVars(other&: *b);
1345 // Merge and align locals of A and B
1346 a->mergeLocalVars(other&: *b);
1347
1348 assert(areVarsAligned(*a, *b) && "IDs expected to be aligned");
1349}
1350
1351// Call 'mergeAndAlignVars' to align constraint systems of 'this' and 'other'.
1352void FlatLinearValueConstraints::mergeAndAlignVarsWithOther(
1353 unsigned offset, FlatLinearValueConstraints *other) {
1354 mergeAndAlignVars(offset, a: this, b: other);
1355}
1356
1357/// Merge and align symbols of `this` and `other` such that both get union of
1358/// of symbols. Existing symbols need not be unique; they will be aligned from
1359/// left to right with duplicates aligned in the same order. Symbols with Value
1360/// as `None` are considered to be inequal to all other symbols.
1361void FlatLinearValueConstraints::mergeSymbolVars(
1362 FlatLinearValueConstraints &other) {
1363
1364 SmallVector<Value, 4> aSymValues;
1365 getValues(start: getNumDimVars(), end: getNumDimAndSymbolVars(), values: &aSymValues);
1366
1367 // Merge symbols: merge symbols into `other` first from `this`.
1368 unsigned s = other.getNumDimVars();
1369 for (Value aSymValue : aSymValues) {
1370 unsigned loc;
1371 // If the var is a symbol in `other`, then align it, otherwise assume that
1372 // it is a new symbol. Search in `other` starting at position `s` since the
1373 // left of it is aligned.
1374 if (other.findVar(val: aSymValue, pos: &loc, offset: s) && loc >= other.getNumDimVars() &&
1375 loc < other.getNumDimAndSymbolVars())
1376 other.swapVar(posA: s, posB: loc);
1377 else
1378 other.insertSymbolVar(pos: s - other.getNumDimVars(), vals: aSymValue);
1379 s++;
1380 }
1381
1382 // Symbols that are in other, but not in this, are added at the end.
1383 for (unsigned t = other.getNumDimVars() + getNumSymbolVars(),
1384 e = other.getNumDimAndSymbolVars();
1385 t < e; t++)
1386 insertSymbolVar(pos: getNumSymbolVars(), vals: other.getValue(pos: t));
1387
1388 assert(getNumSymbolVars() == other.getNumSymbolVars() &&
1389 "expected same number of symbols");
1390}
1391
1392void FlatLinearValueConstraints::removeVarRange(VarKind kind, unsigned varStart,
1393 unsigned varLimit) {
1394 IntegerPolyhedron::removeVarRange(kind, varStart, varLimit);
1395}
1396
1397AffineMap
1398FlatLinearValueConstraints::computeAlignedMap(AffineMap map,
1399 ValueRange operands) const {
1400 assert(map.getNumInputs() == operands.size() && "number of inputs mismatch");
1401
1402 SmallVector<Value> dims, syms;
1403#ifndef NDEBUG
1404 SmallVector<Value> newSyms;
1405 SmallVector<Value> *newSymsPtr = &newSyms;
1406#else
1407 SmallVector<Value> *newSymsPtr = nullptr;
1408#endif // NDEBUG
1409
1410 dims.reserve(N: getNumDimVars());
1411 syms.reserve(N: getNumSymbolVars());
1412 for (unsigned i = 0, e = getNumVarKind(kind: VarKind::SetDim); i < e; ++i) {
1413 Identifier id = space.getId(kind: VarKind::SetDim, pos: i);
1414 dims.push_back(Elt: id.hasValue() ? Value(id.getValue<Value>()) : Value());
1415 }
1416 for (unsigned i = 0, e = getNumVarKind(kind: VarKind::Symbol); i < e; ++i) {
1417 Identifier id = space.getId(kind: VarKind::Symbol, pos: i);
1418 syms.push_back(Elt: id.hasValue() ? Value(id.getValue<Value>()) : Value());
1419 }
1420
1421 AffineMap alignedMap =
1422 alignAffineMapWithValues(map, operands, dims, syms, newSyms: newSymsPtr);
1423 // All symbols are already part of this FlatAffineValueConstraints.
1424 assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
1425 assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
1426 "unexpected new/missing symbols");
1427 return alignedMap;
1428}
1429
1430bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos,
1431 unsigned offset) const {
1432 SmallVector<std::optional<Value>> maybeValues = getMaybeValues();
1433 for (unsigned i = offset, e = maybeValues.size(); i < e; ++i)
1434 if (maybeValues[i] && maybeValues[i].value() == val) {
1435 *pos = i;
1436 return true;
1437 }
1438 return false;
1439}
1440
1441bool FlatLinearValueConstraints::containsVar(Value val) const {
1442 unsigned pos;
1443 return findVar(val, pos: &pos, offset: 0);
1444}
1445
1446void FlatLinearValueConstraints::addBound(BoundType type, Value val,
1447 int64_t value) {
1448 unsigned pos;
1449 if (!findVar(val, pos: &pos))
1450 // This is a pre-condition for this method.
1451 assert(0 && "var not found");
1452 addBound(type, pos, value);
1453}
1454
1455void FlatLinearConstraints::printSpace(raw_ostream &os) const {
1456 IntegerPolyhedron::printSpace(os);
1457 os << "(";
1458 for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++)
1459 os << "None\t";
1460 for (unsigned i = getVarKindOffset(kind: VarKind::Local),
1461 e = getVarKindEnd(kind: VarKind::Local);
1462 i < e; ++i)
1463 os << "Local\t";
1464 os << "const)\n";
1465}
1466
1467void FlatLinearValueConstraints::printSpace(raw_ostream &os) const {
1468 IntegerPolyhedron::printSpace(os);
1469 os << "(";
1470 for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) {
1471 if (hasValue(pos: i))
1472 os << "Value\t";
1473 else
1474 os << "None\t";
1475 }
1476 for (unsigned i = getVarKindOffset(kind: VarKind::Local),
1477 e = getVarKindEnd(kind: VarKind::Local);
1478 i < e; ++i)
1479 os << "Local\t";
1480 os << "const)\n";
1481}
1482
1483void FlatLinearValueConstraints::projectOut(Value val) {
1484 unsigned pos;
1485 bool ret = findVar(val, pos: &pos);
1486 assert(ret);
1487 (void)ret;
1488 fourierMotzkinEliminate(pos);
1489}
1490
1491LogicalResult FlatLinearValueConstraints::unionBoundingBox(
1492 const FlatLinearValueConstraints &otherCst) {
1493 assert(otherCst.getNumDimVars() == getNumDimVars() && "dims mismatch");
1494 SmallVector<std::optional<Value>> maybeValues = getMaybeValues(),
1495 otherMaybeValues =
1496 otherCst.getMaybeValues();
1497 assert(std::equal(maybeValues.begin(), maybeValues.begin() + getNumDimVars(),
1498 otherMaybeValues.begin(),
1499 otherMaybeValues.begin() + getNumDimVars()) &&
1500 "dim values mismatch");
1501 assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here");
1502 assert(getNumLocalVars() == 0 && "local vars not supported yet here");
1503
1504 // Align `other` to this.
1505 if (!areVarsAligned(a: *this, b: otherCst)) {
1506 FlatLinearValueConstraints otherCopy(otherCst);
1507 mergeAndAlignVars(/*offset=*/getNumDimVars(), a: this, b: &otherCopy);
1508 return IntegerPolyhedron::unionBoundingBox(other: otherCopy);
1509 }
1510
1511 return IntegerPolyhedron::unionBoundingBox(other: otherCst);
1512}
1513
1514//===----------------------------------------------------------------------===//
1515// Helper functions
1516//===----------------------------------------------------------------------===//
1517
1518AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
1519 ValueRange dims, ValueRange syms,
1520 SmallVector<Value> *newSyms) {
1521 assert(operands.size() == map.getNumInputs() &&
1522 "expected same number of operands and map inputs");
1523 MLIRContext *ctx = map.getContext();
1524 Builder builder(ctx);
1525 SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {});
1526 unsigned numSymbols = syms.size();
1527 SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {});
1528 if (newSyms) {
1529 newSyms->clear();
1530 newSyms->append(in_start: syms.begin(), in_end: syms.end());
1531 }
1532
1533 for (const auto &operand : llvm::enumerate(First&: operands)) {
1534 // Compute replacement dim/sym of operand.
1535 AffineExpr replacement;
1536 auto dimIt = llvm::find(Range&: dims, Val: operand.value());
1537 auto symIt = llvm::find(Range&: syms, Val: operand.value());
1538 if (dimIt != dims.end()) {
1539 replacement =
1540 builder.getAffineDimExpr(position: std::distance(first: dims.begin(), last: dimIt));
1541 } else if (symIt != syms.end()) {
1542 replacement =
1543 builder.getAffineSymbolExpr(position: std::distance(first: syms.begin(), last: symIt));
1544 } else {
1545 // This operand is neither a dimension nor a symbol. Add it as a new
1546 // symbol.
1547 replacement = builder.getAffineSymbolExpr(position: numSymbols++);
1548 if (newSyms)
1549 newSyms->push_back(Elt: operand.value());
1550 }
1551 // Add to corresponding replacements vector.
1552 if (operand.index() < map.getNumDims()) {
1553 dimReplacements[operand.index()] = replacement;
1554 } else {
1555 symReplacements[operand.index() - map.getNumDims()] = replacement;
1556 }
1557 }
1558
1559 return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
1560 numResultDims: dims.size(), numResultSyms: numSymbols);
1561}
1562
1563LogicalResult
1564mlir::getMultiAffineFunctionFromMap(AffineMap map,
1565 MultiAffineFunction &multiAff) {
1566 FlatLinearConstraints cst;
1567 std::vector<SmallVector<int64_t, 8>> flattenedExprs;
1568 LogicalResult result = getFlattenedAffineExprs(map, flattenedExprs: &flattenedExprs, localVarCst: &cst);
1569
1570 if (result.failed())
1571 return failure();
1572
1573 DivisionRepr divs = cst.getLocalReprs();
1574 assert(divs.hasAllReprs() &&
1575 "AffineMap cannot produce divs without local representation");
1576
1577 // TODO: We shouldn't have to do this conversion.
1578 Matrix<DynamicAPInt> mat(map.getNumResults(),
1579 map.getNumInputs() + divs.getNumDivs() + 1);
1580 for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i)
1581 for (unsigned j = 0, f = flattenedExprs[i].size(); j < f; ++j)
1582 mat(i, j) = flattenedExprs[i][j];
1583
1584 multiAff = MultiAffineFunction(
1585 PresburgerSpace::getRelationSpace(numDomain: map.getNumDims(), numRange: map.getNumResults(),
1586 numSymbols: map.getNumSymbols(), numLocals: divs.getNumDivs()),
1587 mat, divs);
1588
1589 return success();
1590}
1591

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Analysis/FlatLinearValueConstraints.cpp