1//===- FlatLinearValueConstraints.cpp - Linear Constraint -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis//FlatLinearValueConstraints.h"
10
11#include "mlir/Analysis/Presburger/PresburgerSpace.h"
12#include "mlir/Analysis/Presburger/Simplex.h"
13#include "mlir/Analysis/Presburger/Utils.h"
14#include "mlir/IR/AffineExprVisitor.h"
15#include "mlir/IR/Builders.h"
16#include "mlir/IR/IntegerSet.h"
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/Support/Debug.h"
21#include "llvm/Support/InterleavedRange.h"
22#include "llvm/Support/raw_ostream.h"
23#include <optional>
24
25#define DEBUG_TYPE "flat-value-constraints"
26
27using namespace mlir;
28using namespace presburger;
29
30//===----------------------------------------------------------------------===//
31// AffineExprFlattener
32//===----------------------------------------------------------------------===//
33
34namespace {
35
36// See comments for SimpleAffineExprFlattener.
37// An AffineExprFlattenerWithLocalVars extends a SimpleAffineExprFlattener by
38// recording constraint information associated with mod's, floordiv's, and
39// ceildiv's in FlatLinearConstraints 'localVarCst'.
40struct AffineExprFlattener : public SimpleAffineExprFlattener {
41 using SimpleAffineExprFlattener::SimpleAffineExprFlattener;
42
43 // Constraints connecting newly introduced local variables (for mod's and
44 // div's) to existing (dimensional and symbolic) ones. These are always
45 // inequalities.
46 IntegerPolyhedron localVarCst;
47
48 AffineExprFlattener(unsigned nDims, unsigned nSymbols)
49 : SimpleAffineExprFlattener(nDims, nSymbols),
50 localVarCst(PresburgerSpace::getSetSpace(numDims: nDims, numSymbols: nSymbols)) {};
51
52private:
53 // Add a local variable (needed to flatten a mod, floordiv, ceildiv expr).
54 // The local variable added is always a floordiv of a pure add/mul affine
55 // function of other variables, coefficients of which are specified in
56 // `dividend' and with respect to the positive constant `divisor'. localExpr
57 // is the simplified tree expression (AffineExpr) corresponding to the
58 // quantifier.
59 void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
60 AffineExpr localExpr) override {
61 SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
62 // Update localVarCst.
63 localVarCst.addLocalFloorDiv(dividend, divisor);
64 }
65
66 LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
67 ArrayRef<int64_t> rhs,
68 AffineExpr localExpr) override {
69 // AffineExprFlattener does not support semi-affine expressions.
70 return failure();
71 }
72};
73
74// A SemiAffineExprFlattener is an AffineExprFlattenerWithLocalVars that adds
75// conservative bounds for semi-affine expressions (given assumptions hold). If
76// the assumptions required to add the semi-affine bounds are found not to hold
77// the final constraints set will be empty/inconsistent. If the assumptions are
78// never contradicted the final bounds still only will be correct if the
79// assumptions hold.
80struct SemiAffineExprFlattener : public AffineExprFlattener {
81 using AffineExprFlattener::AffineExprFlattener;
82
83 LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
84 ArrayRef<int64_t> rhs,
85 AffineExpr localExpr) override {
86 auto result =
87 SimpleAffineExprFlattener::addLocalIdSemiAffine(lhs, rhs, localExpr);
88 assert(succeeded(result) &&
89 "unexpected failure in SimpleAffineExprFlattener");
90 (void)result;
91
92 if (localExpr.getKind() == AffineExprKind::Mod) {
93 // Given two numbers a and b, division is defined as:
94 //
95 // a = bq + r
96 // 0 <= r < |b| (where |x| is the absolute value of x)
97 //
98 // q = a floordiv b
99 // r = a mod b
100
101 // Add a new local variable (r) to represent the mod.
102 unsigned rPos = localVarCst.appendVar(kind: VarKind::Local);
103
104 // r >= 0 (Can ALWAYS be added)
105 localVarCst.addBound(type: BoundType::LB, pos: rPos, value: 0);
106
107 // r < b (Can be added if b > 0, which we assume here)
108 ArrayRef<int64_t> b = rhs;
109 SmallVector<int64_t> bSubR(b);
110 bSubR.insert(I: bSubR.begin() + rPos, Elt: -1);
111 // Note: bSubR = b - r
112 // So this adds the bound b - r >= 1 (equivalent to r < b)
113 localVarCst.addBound(type: BoundType::LB, expr: bSubR, value: 1);
114
115 // Note: The assumption of b > 0 is based on the affine expression docs,
116 // which state "RHS of mod is always a constant or a symbolic expression
117 // with a positive value." (see AffineExprKind in AffineExpr.h). If this
118 // assumption does not hold constraints (added above) are a contradiction.
119
120 return success();
121 }
122
123 // TODO: Support other semi-affine expressions.
124 return failure();
125 }
126};
127
128} // namespace
129
130// Flattens the expressions in map. Returns failure if 'expr' was unable to be
131// flattened. For example two specific cases:
132// 1. an unhandled semi-affine expressions is found.
133// 2. has poison expression (i.e., division by zero).
134static LogicalResult
135getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
136 unsigned numSymbols,
137 std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
138 FlatLinearConstraints *localVarCst,
139 bool addConservativeSemiAffineBounds = false) {
140 if (exprs.empty()) {
141 if (localVarCst)
142 *localVarCst = FlatLinearConstraints(numDims, numSymbols);
143 return success();
144 }
145
146 auto flattenExprs = [&](AffineExprFlattener &flattener) -> LogicalResult {
147 // Use the same flattener to simplify each expression successively. This way
148 // local variables / expressions are shared.
149 for (auto expr : exprs) {
150 auto flattenResult = flattener.walkPostOrder(expr);
151 if (failed(Result: flattenResult))
152 return failure();
153 }
154
155 assert(flattener.operandExprStack.size() == exprs.size());
156 flattenedExprs->clear();
157 flattenedExprs->assign(first: flattener.operandExprStack.begin(),
158 last: flattener.operandExprStack.end());
159
160 if (localVarCst)
161 localVarCst->clearAndCopyFrom(other: flattener.localVarCst);
162
163 return success();
164 };
165
166 if (addConservativeSemiAffineBounds) {
167 SemiAffineExprFlattener flattener(numDims, numSymbols);
168 return flattenExprs(flattener);
169 }
170
171 AffineExprFlattener flattener(numDims, numSymbols);
172 return flattenExprs(flattener);
173}
174
175// Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
176// be flattened (an unhandled semi-affine was found).
177LogicalResult mlir::getFlattenedAffineExpr(
178 AffineExpr expr, unsigned numDims, unsigned numSymbols,
179 SmallVectorImpl<int64_t> *flattenedExpr, FlatLinearConstraints *localVarCst,
180 bool addConservativeSemiAffineBounds) {
181 std::vector<SmallVector<int64_t, 8>> flattenedExprs;
182 LogicalResult ret =
183 ::getFlattenedAffineExprs(exprs: {expr}, numDims, numSymbols, flattenedExprs: &flattenedExprs,
184 localVarCst, addConservativeSemiAffineBounds);
185 *flattenedExpr = flattenedExprs[0];
186 return ret;
187}
188
189/// Flattens the expressions in map. Returns failure if 'expr' was unable to be
190/// flattened (i.e., an unhandled semi-affine was found).
191LogicalResult mlir::getFlattenedAffineExprs(
192 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
193 FlatLinearConstraints *localVarCst, bool addConservativeSemiAffineBounds) {
194 if (map.getNumResults() == 0) {
195 if (localVarCst)
196 *localVarCst =
197 FlatLinearConstraints(map.getNumDims(), map.getNumSymbols());
198 return success();
199 }
200 return ::getFlattenedAffineExprs(
201 exprs: map.getResults(), numDims: map.getNumDims(), numSymbols: map.getNumSymbols(), flattenedExprs,
202 localVarCst, addConservativeSemiAffineBounds);
203}
204
205LogicalResult mlir::getFlattenedAffineExprs(
206 IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
207 FlatLinearConstraints *localVarCst) {
208 if (set.getNumConstraints() == 0) {
209 if (localVarCst)
210 *localVarCst =
211 FlatLinearConstraints(set.getNumDims(), set.getNumSymbols());
212 return success();
213 }
214 return ::getFlattenedAffineExprs(exprs: set.getConstraints(), numDims: set.getNumDims(),
215 numSymbols: set.getNumSymbols(), flattenedExprs,
216 localVarCst);
217}
218
219//===----------------------------------------------------------------------===//
220// FlatLinearConstraints
221//===----------------------------------------------------------------------===//
222
223// Similar to `composeMap` except that no Values need be associated with the
224// constraint system nor are they looked at -- the dimensions and symbols of
225// `other` are expected to correspond 1:1 to `this` system.
226LogicalResult FlatLinearConstraints::composeMatchingMap(AffineMap other) {
227 assert(other.getNumDims() == getNumDimVars() && "dim mismatch");
228 assert(other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
229
230 std::vector<SmallVector<int64_t, 8>> flatExprs;
231 if (failed(Result: flattenAlignedMapAndMergeLocals(map: other, flattenedExprs: &flatExprs)))
232 return failure();
233 assert(flatExprs.size() == other.getNumResults());
234
235 // Add dimensions corresponding to the map's results.
236 insertDimVar(/*pos=*/0, /*num=*/other.getNumResults());
237
238 // We add one equality for each result connecting the result dim of the map to
239 // the other variables.
240 // E.g.: if the expression is 16*i0 + i1, and this is the r^th
241 // iteration/result of the value map, we are adding the equality:
242 // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
243 // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
244 for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
245 const auto &flatExpr = flatExprs[r];
246 assert(flatExpr.size() >= other.getNumInputs() + 1);
247
248 SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
249 // Set the coefficient for this result to one.
250 eqToAdd[r] = 1;
251
252 // Dims and symbols.
253 for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
254 // Negate `eq[r]` since the newly added dimension will be set to this one.
255 eqToAdd[e + i] = -flatExpr[i];
256 }
257 // Local columns of `eq` are at the beginning.
258 unsigned j = getNumDimVars() + getNumSymbolVars();
259 unsigned end = flatExpr.size() - 1;
260 for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
261 eqToAdd[j] = -flatExpr[i];
262 }
263
264 // Constant term.
265 eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
266
267 // Add the equality connecting the result of the map to this constraint set.
268 addEquality(eq: eqToAdd);
269 }
270
271 return success();
272}
273
274// Determine whether the variable at 'pos' (say var_r) can be expressed as
275// modulo of another known variable (say var_n) w.r.t a constant. For example,
276// if the following constraints hold true:
277// ```
278// 0 <= var_r <= divisor - 1
279// var_n - (divisor * q_expr) = var_r
280// ```
281// where `var_n` is a known variable (called dividend), and `q_expr` is an
282// `AffineExpr` (called the quotient expression), `var_r` can be written as:
283//
284// `var_r = var_n mod divisor`.
285//
286// Additionally, in a special case of the above constaints where `q_expr` is an
287// variable itself that is not yet known (say `var_q`), it can be written as a
288// floordiv in the following way:
289//
290// `var_q = var_n floordiv divisor`.
291//
292// First 'num' dimensional variables starting at 'offset' are
293// derived/to-be-derived in terms of the remaining variables. The remaining
294// variables are assigned trivial affine expressions in `memo`. For example,
295// memo is initilized as follows for a `cst` with 5 dims, when offset=2, num=2:
296// memo ==> d0 d1 . . d2 ...
297// cst ==> c0 c1 c2 c3 c4 ...
298//
299// Returns true if the above mod or floordiv are detected, updating 'memo' with
300// these new expressions. Returns false otherwise.
301static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos,
302 unsigned offset, unsigned num, int64_t lbConst,
303 int64_t ubConst, MLIRContext *context,
304 SmallVectorImpl<AffineExpr> &memo) {
305 assert(pos < cst.getNumVars() && "invalid position");
306
307 // Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can
308 // be determined.
309 if (lbConst != 0 || ubConst < 1)
310 return false;
311 int64_t divisor = ubConst + 1;
312
313 // Check for the aforementioned conditions in each equality.
314 for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
315 curEquality < numEqualities; curEquality++) {
316 int64_t coefficientAtPos = cst.atEq64(i: curEquality, j: pos);
317 // If current equality does not involve `var_r`, continue to the next
318 // equality.
319 if (coefficientAtPos == 0)
320 continue;
321
322 // Constant term should be 0 in this equality.
323 if (cst.atEq64(i: curEquality, j: cst.getNumCols() - 1) != 0)
324 continue;
325
326 // Traverse through the equality and construct the dividend expression
327 // `dividendExpr`, to contain all the variables which are known and are
328 // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the
329 // `dividendExpr` gets simplified into a single variable `var_n` discussed
330 // above.
331 auto dividendExpr = getAffineConstantExpr(constant: 0, context);
332
333 // Track the terms that go into quotient expression, later used to detect
334 // additional floordiv.
335 unsigned quotientCount = 0;
336 int quotientPosition = -1;
337 int quotientSign = 1;
338
339 // Consider each term in the current equality.
340 unsigned curVar, e;
341 for (curVar = 0, e = cst.getNumDimAndSymbolVars(); curVar < e; ++curVar) {
342 // Ignore var_r.
343 if (curVar == pos)
344 continue;
345 int64_t coefficientOfCurVar = cst.atEq64(i: curEquality, j: curVar);
346 // Ignore vars that do not contribute to the current equality.
347 if (coefficientOfCurVar == 0)
348 continue;
349 // Check if the current var goes into the quotient expression.
350 if (coefficientOfCurVar % (divisor * coefficientAtPos) == 0) {
351 quotientCount++;
352 quotientPosition = curVar;
353 quotientSign = (coefficientOfCurVar * coefficientAtPos) > 0 ? 1 : -1;
354 continue;
355 }
356 // Variables that are part of dividendExpr should be known.
357 if (!memo[curVar])
358 break;
359 // Append the current variable to the dividend expression.
360 dividendExpr = dividendExpr + memo[curVar] * coefficientOfCurVar;
361 }
362
363 // Can't construct expression as it depends on a yet uncomputed var.
364 if (curVar < e)
365 continue;
366
367 // Express `var_r` in terms of the other vars collected so far.
368 if (coefficientAtPos > 0)
369 dividendExpr = (-dividendExpr).floorDiv(v: coefficientAtPos);
370 else
371 dividendExpr = dividendExpr.floorDiv(v: -coefficientAtPos);
372
373 // Simplify the expression.
374 dividendExpr = simplifyAffineExpr(expr: dividendExpr, numDims: cst.getNumDimVars(),
375 numSymbols: cst.getNumSymbolVars());
376 // Only if the final dividend expression is just a single var (which we call
377 // `var_n`), we can proceed.
378 // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
379 // to dims themselves.
380 auto dimExpr = dyn_cast<AffineDimExpr>(Val&: dividendExpr);
381 if (!dimExpr)
382 continue;
383
384 // Express `var_r` as `var_n % divisor` and store the expression in `memo`.
385 if (quotientCount >= 1) {
386 // Find the column corresponding to `dimExpr`. `num` columns starting at
387 // `offset` correspond to previously unknown variables. The column
388 // corresponding to the trivially known `dimExpr` can be on either side
389 // of these.
390 unsigned dimExprPos = dimExpr.getPosition();
391 unsigned dimExprCol = dimExprPos < offset ? dimExprPos : dimExprPos + num;
392 auto ub = cst.getConstantBound64(type: BoundType::UB, pos: dimExprCol);
393 // If `var_n` has an upperbound that is less than the divisor, mod can be
394 // eliminated altogether.
395 if (ub && *ub < divisor)
396 memo[pos] = dimExpr;
397 else
398 memo[pos] = dimExpr % divisor;
399 // If a unique quotient `var_q` was seen, it can be expressed as
400 // `var_n floordiv divisor`.
401 if (quotientCount == 1 && !memo[quotientPosition])
402 memo[quotientPosition] = dimExpr.floorDiv(v: divisor) * quotientSign;
403
404 return true;
405 }
406 }
407 return false;
408}
409
410/// Check if the pos^th variable can be expressed as a floordiv of an affine
411/// function of other variables (where the divisor is a positive constant)
412/// given the initial set of expressions in `exprs`. If it can be, the
413/// corresponding position in `exprs` is set as the detected affine expr. For
414/// eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. An equality can
415/// also yield a floordiv: eg. 4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
416/// <= i <= 32q + 31 => q = i floordiv 32.
417static bool detectAsFloorDiv(const FlatLinearConstraints &cst, unsigned pos,
418 MLIRContext *context,
419 SmallVectorImpl<AffineExpr> &exprs) {
420 assert(pos < cst.getNumVars() && "invalid position");
421
422 // Get upper-lower bound pair for this variable.
423 SmallVector<bool, 8> foundRepr(cst.getNumVars(), false);
424 for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i)
425 if (exprs[i])
426 foundRepr[i] = true;
427
428 SmallVector<int64_t, 8> dividend(cst.getNumCols());
429 unsigned divisor;
430 auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor);
431
432 // No upper-lower bound pair found for this var.
433 if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality)
434 return false;
435
436 // Construct the dividend expression.
437 auto dividendExpr = getAffineConstantExpr(constant: dividend.back(), context);
438 for (unsigned c = 0, f = cst.getNumVars(); c < f; c++)
439 if (dividend[c] != 0)
440 dividendExpr = dividendExpr + dividend[c] * exprs[c];
441
442 // Successfully detected the floordiv.
443 exprs[pos] = dividendExpr.floorDiv(v: divisor);
444 return true;
445}
446
447void FlatLinearConstraints::dumpRow(ArrayRef<int64_t> row,
448 bool fixedColWidth) const {
449 unsigned ncols = getNumCols();
450 bool firstNonZero = true;
451 for (unsigned j = 0; j < ncols; j++) {
452 if (j == ncols - 1) {
453 // Constant.
454 if (row[j] == 0 && !firstNonZero) {
455 if (fixedColWidth)
456 llvm::errs().indent(NumSpaces: 7);
457 } else {
458 llvm::errs() << ((row[j] >= 0) ? "+ " : "") << row[j] << ' ';
459 }
460 } else {
461 std::string var = std::string("c_") + std::to_string(val: j);
462 if (row[j] == 1)
463 llvm::errs() << "+ " << var << ' ';
464 else if (row[j] == -1)
465 llvm::errs() << "- " << var << ' ';
466 else if (row[j] >= 2)
467 llvm::errs() << "+ " << row[j] << '*' << var << ' ';
468 else if (row[j] <= -2)
469 llvm::errs() << "- " << -row[j] << '*' << var << ' ';
470 else if (fixedColWidth)
471 // Zero coeff.
472 llvm::errs().indent(NumSpaces: 7);
473 if (row[j] != 0)
474 firstNonZero = false;
475 }
476 }
477}
478
479void FlatLinearConstraints::dumpPretty() const {
480 assert(hasConsistentState());
481 llvm::errs() << "Constraints (" << getNumDimVars() << " dims, "
482 << getNumSymbolVars() << " symbols, " << getNumLocalVars()
483 << " locals), (" << getNumConstraints() << " constraints)\n";
484 auto dumpConstraint = [&](unsigned rowPos, bool isEq) {
485 // Is it the first non-zero entry?
486 SmallVector<int64_t> row =
487 isEq ? getEquality64(idx: rowPos) : getInequality64(idx: rowPos);
488 dumpRow(row);
489 llvm::errs() << (isEq ? "=" : ">=") << " 0\n";
490 };
491
492 for (unsigned i = 0, e = getNumInequalities(); i < e; i++)
493 dumpConstraint(i, /*isEq=*/false);
494 for (unsigned i = 0, e = getNumEqualities(); i < e; i++)
495 dumpConstraint(i, /*isEq=*/true);
496 llvm::errs() << '\n';
497}
498
499std::pair<AffineMap, AffineMap> FlatLinearConstraints::getLowerAndUpperBound(
500 unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
501 ArrayRef<AffineExpr> localExprs, MLIRContext *context,
502 bool closedUB) const {
503 assert(pos + offset < getNumDimVars() && "invalid dim start pos");
504 assert(symStartPos >= (pos + offset) && "invalid sym start pos");
505 assert(getNumLocalVars() == localExprs.size() &&
506 "incorrect local exprs count");
507
508 SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
509 getLowerAndUpperBoundIndices(pos: pos + offset, lbIndices: &lbIndices, ubIndices: &ubIndices, eqIndices: &eqIndices,
510 offset, num);
511
512 /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
513 auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
514 b.clear();
515 for (unsigned i = 0, e = a.size(); i < e; ++i) {
516 if (i < offset || i >= offset + num)
517 b.push_back(Elt: a[i]);
518 }
519 };
520
521 SmallVector<int64_t, 8> lb, ub;
522 SmallVector<AffineExpr, 4> lbExprs;
523 unsigned dimCount = symStartPos - num;
524 unsigned symCount = getNumDimAndSymbolVars() - symStartPos;
525 lbExprs.reserve(N: lbIndices.size() + eqIndices.size());
526 // Lower bound expressions.
527 for (auto idx : lbIndices) {
528 auto ineq = getInequality64(idx);
529 // Extract the lower bound (in terms of other coeff's + const), i.e., if
530 // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
531 // - 1.
532 addCoeffs(ineq, lb);
533 std::transform(first: lb.begin(), last: lb.end(), result: lb.begin(), unary_op: std::negate<int64_t>());
534 auto expr =
535 getAffineExprFromFlatForm(flatExprs: lb, numDims: dimCount, numSymbols: symCount, localExprs, context);
536 // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
537 int64_t divisor = std::abs(i: ineq[pos + offset]);
538 expr = (expr + divisor - 1).floorDiv(v: divisor);
539 lbExprs.push_back(Elt: expr);
540 }
541
542 SmallVector<AffineExpr, 4> ubExprs;
543 ubExprs.reserve(N: ubIndices.size() + eqIndices.size());
544 // Upper bound expressions.
545 for (auto idx : ubIndices) {
546 auto ineq = getInequality64(idx);
547 // Extract the upper bound (in terms of other coeff's + const).
548 addCoeffs(ineq, ub);
549 auto expr =
550 getAffineExprFromFlatForm(flatExprs: ub, numDims: dimCount, numSymbols: symCount, localExprs, context);
551 expr = expr.floorDiv(v: std::abs(i: ineq[pos + offset]));
552 int64_t ubAdjustment = closedUB ? 0 : 1;
553 ubExprs.push_back(Elt: expr + ubAdjustment);
554 }
555
556 // Equalities. It's both a lower and a upper bound.
557 SmallVector<int64_t, 4> b;
558 for (auto idx : eqIndices) {
559 auto eq = getEquality64(idx);
560 addCoeffs(eq, b);
561 if (eq[pos + offset] > 0)
562 std::transform(first: b.begin(), last: b.end(), result: b.begin(), unary_op: std::negate<int64_t>());
563
564 // Extract the upper bound (in terms of other coeff's + const).
565 auto expr =
566 getAffineExprFromFlatForm(flatExprs: b, numDims: dimCount, numSymbols: symCount, localExprs, context);
567 expr = expr.floorDiv(v: std::abs(i: eq[pos + offset]));
568 // Upper bound is exclusive.
569 ubExprs.push_back(Elt: expr + 1);
570 // Lower bound.
571 expr =
572 getAffineExprFromFlatForm(flatExprs: b, numDims: dimCount, numSymbols: symCount, localExprs, context);
573 expr = expr.ceilDiv(v: std::abs(i: eq[pos + offset]));
574 lbExprs.push_back(Elt: expr);
575 }
576
577 auto lbMap = AffineMap::get(dimCount, symbolCount: symCount, results: lbExprs, context);
578 auto ubMap = AffineMap::get(dimCount, symbolCount: symCount, results: ubExprs, context);
579
580 return {lbMap, ubMap};
581}
582
583/// Express the pos^th identifier of `cst` as an affine expression in
584/// terms of other identifiers, if they are available in `exprs`, using the
585/// equality at position `idx` in `cs`t. Populates `exprs` with such an
586/// expression if possible, and return true. Returns false otherwise.
587static bool detectAsExpr(const FlatLinearConstraints &cst, unsigned pos,
588 unsigned idx, MLIRContext *context,
589 SmallVectorImpl<AffineExpr> &exprs) {
590 // Initialize with a `0` expression.
591 auto expr = getAffineConstantExpr(constant: 0, context);
592
593 // Traverse `idx`th equality and construct the possible affine expression in
594 // terms of known identifiers.
595 unsigned j, e;
596 for (j = 0, e = cst.getNumVars(); j < e; ++j) {
597 if (j == pos)
598 continue;
599 int64_t c = cst.atEq64(i: idx, j);
600 if (c == 0)
601 continue;
602 // If any of the involved IDs hasn't been found yet, we can't proceed.
603 if (!exprs[j])
604 break;
605 expr = expr + exprs[j] * c;
606 }
607 if (j < e)
608 // Can't construct expression as it depends on a yet uncomputed
609 // identifier.
610 return false;
611
612 // Add constant term to AffineExpr.
613 expr = expr + cst.atEq64(i: idx, j: cst.getNumVars());
614 int64_t vPos = cst.atEq64(i: idx, j: pos);
615 assert(vPos != 0 && "expected non-zero here");
616 if (vPos > 0)
617 expr = (-expr).floorDiv(v: vPos);
618 else
619 // vPos < 0.
620 expr = expr.floorDiv(v: -vPos);
621 // Successfully constructed expression.
622 exprs[pos] = expr;
623 return true;
624}
625
626/// Compute a representation of `num` identifiers starting at `offset` in `cst`
627/// as affine expressions involving other known identifiers. Each identifier's
628/// expression (in terms of known identifiers) is populated into `memo`.
629static void computeUnknownVars(const FlatLinearConstraints &cst,
630 MLIRContext *context, unsigned offset,
631 unsigned num,
632 SmallVectorImpl<AffineExpr> &memo) {
633 // Initialize dimensional and symbolic variables.
634 for (unsigned i = 0, e = cst.getNumDimVars(); i < e; i++) {
635 if (i < offset)
636 memo[i] = getAffineDimExpr(position: i, context);
637 else if (i >= offset + num)
638 memo[i] = getAffineDimExpr(position: i - num, context);
639 }
640 for (unsigned i = cst.getNumDimVars(), e = cst.getNumDimAndSymbolVars();
641 i < e; i++)
642 memo[i] = getAffineSymbolExpr(position: i - cst.getNumDimVars(), context);
643
644 bool changed;
645 do {
646 changed = false;
647 // Identify yet unknown variables as constants or mod's / floordiv's of
648 // other variables if possible.
649 for (unsigned pos = 0, f = cst.getNumVars(); pos < f; pos++) {
650 if (memo[pos])
651 continue;
652
653 auto lbConst = cst.getConstantBound64(type: BoundType::LB, pos);
654 auto ubConst = cst.getConstantBound64(type: BoundType::UB, pos);
655 if (lbConst.has_value() && ubConst.has_value()) {
656 // Detect equality to a constant.
657 if (*lbConst == *ubConst) {
658 memo[pos] = getAffineConstantExpr(constant: *lbConst, context);
659 changed = true;
660 continue;
661 }
662
663 // Detect a variable as modulo of another variable w.r.t a
664 // constant.
665 if (detectAsMod(cst, pos, offset, num, lbConst: *lbConst, ubConst: *ubConst, context,
666 memo)) {
667 changed = true;
668 continue;
669 }
670 }
671
672 // Detect a variable as a floordiv of an affine function of other
673 // variables (divisor is a positive constant).
674 if (detectAsFloorDiv(cst, pos, context, exprs&: memo)) {
675 changed = true;
676 continue;
677 }
678
679 // Detect a variable as an expression of other variables.
680 std::optional<unsigned> idx;
681 if (!(idx = cst.findConstraintWithNonZeroAt(colIdx: pos, /*isEq=*/true)))
682 continue;
683
684 if (detectAsExpr(cst, pos, idx: *idx, context, exprs&: memo)) {
685 changed = true;
686 continue;
687 }
688 }
689 // This loop is guaranteed to reach a fixed point - since once an
690 // variable's explicit form is computed (in memo[pos]), it's not updated
691 // again.
692 } while (changed);
693}
694
695/// Computes the lower and upper bounds of the first 'num' dimensional
696/// variables (starting at 'offset') as affine maps of the remaining
697/// variables (dimensional and symbolic variables). Local variables are
698/// themselves explicitly computed as affine functions of other variables in
699/// this process if needed.
700void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num,
701 MLIRContext *context,
702 SmallVectorImpl<AffineMap> *lbMaps,
703 SmallVectorImpl<AffineMap> *ubMaps,
704 bool closedUB) {
705 assert(offset + num <= getNumDimVars() && "invalid range");
706
707 // Basic simplification.
708 normalizeConstraintsByGCD();
709
710 LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for variables at positions ["
711 << offset << ", " << offset + num << ")\n");
712 LLVM_DEBUG(dumpPretty());
713
714 // Record computed/detected variables.
715 SmallVector<AffineExpr, 8> memo(getNumVars());
716 computeUnknownVars(cst: *this, context, offset, num, memo);
717
718 int64_t ubAdjustment = closedUB ? 0 : 1;
719
720 // Set the lower and upper bound maps for all the variables that were
721 // computed as affine expressions of the rest as the "detected expr" and
722 // "detected expr + 1" respectively; set the undetected ones to null.
723 std::optional<FlatLinearConstraints> tmpClone;
724 for (unsigned pos = 0; pos < num; pos++) {
725 unsigned numMapDims = getNumDimVars() - num;
726 unsigned numMapSymbols = getNumSymbolVars();
727 AffineExpr expr = memo[pos + offset];
728 if (expr)
729 expr = simplifyAffineExpr(expr, numDims: numMapDims, numSymbols: numMapSymbols);
730
731 AffineMap &lbMap = (*lbMaps)[pos];
732 AffineMap &ubMap = (*ubMaps)[pos];
733
734 if (expr) {
735 lbMap = AffineMap::get(dimCount: numMapDims, symbolCount: numMapSymbols, result: expr);
736 ubMap = AffineMap::get(dimCount: numMapDims, symbolCount: numMapSymbols, result: expr + ubAdjustment);
737 } else {
738 // TODO: Whenever there are local variables in the dependence
739 // constraints, we'll conservatively over-approximate, since we don't
740 // always explicitly compute them above (in the while loop).
741 if (getNumLocalVars() == 0) {
742 // Work on a copy so that we don't update this constraint system.
743 if (!tmpClone) {
744 tmpClone.emplace(args: FlatLinearConstraints(*this));
745 // Removing redundant inequalities is necessary so that we don't get
746 // redundant loop bounds.
747 tmpClone->removeRedundantInequalities();
748 }
749 std::tie(args&: lbMap, args&: ubMap) = tmpClone->getLowerAndUpperBound(
750 pos, offset, num, symStartPos: getNumDimVars(), /*localExprs=*/{}, context,
751 closedUB);
752 }
753
754 // If the above fails, we'll just use the constant lower bound and the
755 // constant upper bound (if they exist) as the slice bounds.
756 // TODO: being conservative for the moment in cases that
757 // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
758 // fixed (b/126426796).
759 if (!lbMap || lbMap.getNumResults() != 1) {
760 LLVM_DEBUG(llvm::dbgs()
761 << "WARNING: Potentially over-approximating slice lb\n");
762 auto lbConst = getConstantBound64(type: BoundType::LB, pos: pos + offset);
763 if (lbConst.has_value()) {
764 lbMap = AffineMap::get(dimCount: numMapDims, symbolCount: numMapSymbols,
765 result: getAffineConstantExpr(constant: *lbConst, context));
766 }
767 }
768 if (!ubMap || ubMap.getNumResults() != 1) {
769 LLVM_DEBUG(llvm::dbgs()
770 << "WARNING: Potentially over-approximating slice ub\n");
771 auto ubConst = getConstantBound64(type: BoundType::UB, pos: pos + offset);
772 if (ubConst.has_value()) {
773 ubMap = AffineMap::get(
774 dimCount: numMapDims, symbolCount: numMapSymbols,
775 result: getAffineConstantExpr(constant: *ubConst + ubAdjustment, context));
776 }
777 }
778 }
779
780 LLVM_DEBUG(llvm::dbgs() << "Slice bounds:\n");
781 LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos + offset)
782 << ", expr: " << lbMap << '\n');
783 LLVM_DEBUG(llvm::dbgs() << "ub map for pos = " << Twine(pos + offset)
784 << ", expr: " << ubMap << '\n');
785 }
786}
787
788LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals(
789 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
790 bool addConservativeSemiAffineBounds) {
791 FlatLinearConstraints localCst;
792 if (failed(Result: getFlattenedAffineExprs(map, flattenedExprs, localVarCst: &localCst,
793 addConservativeSemiAffineBounds))) {
794 LLVM_DEBUG(llvm::dbgs()
795 << "composition unimplemented for semi-affine maps\n");
796 return failure();
797 }
798
799 // Add localCst information.
800 if (localCst.getNumLocalVars() > 0) {
801 unsigned numLocalVars = getNumLocalVars();
802 // Insert local dims of localCst at the beginning.
803 insertLocalVar(/*pos=*/0, /*num=*/localCst.getNumLocalVars());
804 // Insert local dims of `this` at the end of localCst.
805 localCst.appendLocalVar(/*num=*/numLocalVars);
806 // Dimensions of localCst and this constraint set match. Append localCst to
807 // this constraint set.
808 append(other: localCst);
809 }
810
811 return success();
812}
813
814LogicalResult FlatLinearConstraints::addBound(
815 BoundType type, unsigned pos, AffineMap boundMap, bool isClosedBound,
816 AddConservativeSemiAffineBounds addSemiAffineBounds) {
817 assert(boundMap.getNumDims() == getNumDimVars() && "dim mismatch");
818 assert(boundMap.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
819 assert(pos < getNumDimAndSymbolVars() && "invalid position");
820 assert((type != BoundType::EQ || isClosedBound) &&
821 "EQ bound must be closed.");
822
823 // Equality follows the logic of lower bound except that we add an equality
824 // instead of an inequality.
825 assert((type != BoundType::EQ || boundMap.getNumResults() == 1) &&
826 "single result expected");
827 bool lower = type == BoundType::LB || type == BoundType::EQ;
828
829 std::vector<SmallVector<int64_t, 8>> flatExprs;
830 if (failed(Result: flattenAlignedMapAndMergeLocals(
831 map: boundMap, flattenedExprs: &flatExprs,
832 addConservativeSemiAffineBounds: addSemiAffineBounds == AddConservativeSemiAffineBounds::Yes)))
833 return failure();
834 assert(flatExprs.size() == boundMap.getNumResults());
835
836 // Add one (in)equality for each result.
837 for (const auto &flatExpr : flatExprs) {
838 SmallVector<int64_t> ineq(getNumCols(), 0);
839 // Dims and symbols.
840 for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
841 ineq[j] = lower ? -flatExpr[j] : flatExpr[j];
842 }
843 // Invalid bound: pos appears in `boundMap`.
844 // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or
845 // its callers to prevent invalid bounds from being added.
846 if (ineq[pos] != 0)
847 continue;
848 ineq[pos] = lower ? 1 : -1;
849 // Local columns of `ineq` are at the beginning.
850 unsigned j = getNumDimVars() + getNumSymbolVars();
851 unsigned end = flatExpr.size() - 1;
852 for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {
853 ineq[j] = lower ? -flatExpr[i] : flatExpr[i];
854 }
855 // Make the bound closed in if flatExpr is open. The inequality is always
856 // created in the upper bound form, so the adjustment is -1.
857 int64_t boundAdjustment = (isClosedBound || type == BoundType::EQ) ? 0 : -1;
858 // Constant term.
859 ineq[getNumCols() - 1] = (lower ? -flatExpr[flatExpr.size() - 1]
860 : flatExpr[flatExpr.size() - 1]) +
861 boundAdjustment;
862 type == BoundType::EQ ? addEquality(eq: ineq) : addInequality(inEq: ineq);
863 }
864
865 return success();
866}
867
868LogicalResult FlatLinearConstraints::addBound(
869 BoundType type, unsigned pos, AffineMap boundMap,
870 AddConservativeSemiAffineBounds addSemiAffineBounds) {
871 return addBound(type, pos, boundMap,
872 /*isClosedBound=*/type != BoundType::UB, addSemiAffineBounds);
873}
874
875/// Compute an explicit representation for local vars. For all systems coming
876/// from MLIR integer sets, maps, or expressions where local vars were
877/// introduced to model floordivs and mods, this always succeeds.
878LogicalResult
879FlatLinearConstraints::computeLocalVars(SmallVectorImpl<AffineExpr> &memo,
880 MLIRContext *context) const {
881 unsigned numDims = getNumDimVars();
882 unsigned numSyms = getNumSymbolVars();
883
884 // Initialize dimensional and symbolic variables.
885 for (unsigned i = 0; i < numDims; i++)
886 memo[i] = getAffineDimExpr(position: i, context);
887 for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
888 memo[i] = getAffineSymbolExpr(position: i - numDims, context);
889
890 bool changed;
891 do {
892 // Each time `changed` is true at the end of this iteration, one or more
893 // local vars would have been detected as floordivs and set in memo; so the
894 // number of null entries in memo[...] strictly reduces; so this converges.
895 changed = false;
896 for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i)
897 if (!memo[numDims + numSyms + i] &&
898 detectAsFloorDiv(cst: *this, /*pos=*/numDims + numSyms + i, context, exprs&: memo))
899 changed = true;
900 } while (changed);
901
902 ArrayRef<AffineExpr> localExprs =
903 ArrayRef<AffineExpr>(memo).take_back(N: getNumLocalVars());
904 return success(
905 IsSuccess: llvm::all_of(Range&: localExprs, P: [](AffineExpr expr) { return expr; }));
906}
907
908/// Given an equality or inequality (`isEquality` used to disambiguate) of `cst`
909/// at `idx`, traverse and sum up `AffineExpr`s of all known ids other than the
910/// `pos`th. Known `AffineExpr`s are given in `exprs` (unknowns are null). If
911/// the equality/inequality contains any unknown id, return None. Otherwise
912/// return sum as `AffineExpr`.
913static std::optional<AffineExpr> getAsExpr(const FlatLinearConstraints &cst,
914 unsigned pos, MLIRContext *context,
915 ArrayRef<AffineExpr> exprs,
916 unsigned idx, bool isEquality) {
917 // Initialize with a `0` expression.
918 auto expr = getAffineConstantExpr(constant: 0, context);
919
920 SmallVector<int64_t, 8> row =
921 isEquality ? cst.getEquality64(idx) : cst.getInequality64(idx);
922
923 // Traverse `idx`th equality and construct the possible affine expression in
924 // terms of known identifiers.
925 unsigned j, e;
926 for (j = 0, e = cst.getNumVars(); j < e; ++j) {
927 if (j == pos)
928 continue;
929 int64_t c = row[j];
930 if (c == 0)
931 continue;
932 // If any of the involved IDs hasn't been found yet, we can't proceed.
933 if (!exprs[j])
934 break;
935 expr = expr + exprs[j] * c;
936 }
937 if (j < e)
938 // Can't construct expression as it depends on a yet uncomputed
939 // identifier.
940 return std::nullopt;
941
942 // Add constant term to AffineExpr.
943 expr = expr + row[cst.getNumVars()];
944 return expr;
945}
946
947std::optional<int64_t> FlatLinearConstraints::getConstantBoundOnDimSize(
948 MLIRContext *context, unsigned pos, AffineMap *lb, AffineMap *ub,
949 unsigned *minLbPos, unsigned *minUbPos) const {
950
951 assert(pos < getNumDimVars() && "Invalid identifier position");
952
953 auto freeOfUnknownLocalVars = [&](ArrayRef<int64_t> cst,
954 ArrayRef<AffineExpr> whiteListCols) {
955 for (int i = getNumDimAndSymbolVars(), e = cst.size() - 1; i < e; ++i) {
956 if (whiteListCols[i] && whiteListCols[i].isSymbolicOrConstant())
957 continue;
958 if (cst[i] != 0)
959 return false;
960 }
961 return true;
962 };
963
964 // Detect the necesary local variables first.
965 SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr());
966 (void)computeLocalVars(memo, context);
967
968 // Find an equality for 'pos'^th identifier that equates it to some function
969 // of the symbolic identifiers (+ constant).
970 int eqPos = findEqualityToConstant(pos, /*symbolic=*/true);
971 // If the equality involves a local var that can not be expressed as a
972 // symbolic or constant affine expression, we bail out.
973 if (eqPos != -1 && freeOfUnknownLocalVars(getEquality64(idx: eqPos), memo)) {
974 // This identifier can only take a single value.
975 if (lb && detectAsExpr(cst: *this, pos, idx: eqPos, context, exprs&: memo)) {
976 AffineExpr equalityExpr =
977 simplifyAffineExpr(expr: memo[pos], numDims: 0, numSymbols: getNumSymbolVars());
978 *lb = AffineMap::get(/*dimCount=*/0, symbolCount: getNumSymbolVars(), result: equalityExpr);
979 if (ub)
980 *ub = *lb;
981 }
982 if (minLbPos)
983 *minLbPos = eqPos;
984 if (minUbPos)
985 *minUbPos = eqPos;
986 return 1;
987 }
988
989 // Positions of constraints that are lower/upper bounds on the variable.
990 SmallVector<unsigned, 4> lbIndices, ubIndices;
991
992 // Note inequalities that give lower and upper bounds.
993 getLowerAndUpperBoundIndices(pos, lbIndices: &lbIndices, ubIndices: &ubIndices,
994 /*eqIndices=*/nullptr, /*offset=*/0,
995 /*num=*/getNumDimVars());
996
997 std::optional<int64_t> minDiff = std::nullopt;
998 unsigned minLbPosition = 0, minUbPosition = 0;
999 AffineExpr minLbExpr, minUbExpr;
1000
1001 // Traverse each lower bound and upper bound pair, to compute the difference
1002 // between them.
1003 for (unsigned ubPos : ubIndices) {
1004 // Construct sum of all ids other than `pos`th in the given upper bound row.
1005 std::optional<AffineExpr> maybeUbExpr =
1006 getAsExpr(cst: *this, pos, context, exprs: memo, idx: ubPos, /*isEquality=*/false);
1007 if (!maybeUbExpr.has_value() || !(*maybeUbExpr).isSymbolicOrConstant())
1008 continue;
1009
1010 // Canonical form of an inequality that constrains the upper bound on
1011 // an id `x_i` is of the form:
1012 // `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` <= -1.
1013 // Therefore the upper bound on `x_i` will be
1014 // `(
1015 // sum(c_j*x_j) where j != i
1016 // +
1017 // c_0
1018 // )
1019 // /
1020 // -(c_i)`. Divison here is a floorDiv.
1021 AffineExpr ubExpr = maybeUbExpr->floorDiv(v: -atIneq64(i: ubPos, j: pos));
1022 assert(-atIneq64(ubPos, pos) > 0 && "invalid upper bound index");
1023
1024 // Go over each lower bound.
1025 for (unsigned lbPos : lbIndices) {
1026 // Construct sum of all ids other than `pos`th in the given lower bound
1027 // row.
1028 std::optional<AffineExpr> maybeLbExpr =
1029 getAsExpr(cst: *this, pos, context, exprs: memo, idx: lbPos, /*isEquality=*/false);
1030 if (!maybeLbExpr.has_value() || !(*maybeLbExpr).isSymbolicOrConstant())
1031 continue;
1032
1033 // Canonical form of an inequality that is constraining the lower bound
1034 // on an id `x_i is of the form:
1035 // `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` >= 1.
1036 // Therefore upperBound on `x_i` will be
1037 // `-(
1038 // sum(c_j*x_j) where j != i
1039 // +
1040 // c_0
1041 // )
1042 // /
1043 // c_i`. Divison here is a ceilDiv.
1044 int64_t divisor = atIneq64(i: lbPos, j: pos);
1045 // We convert the `ceilDiv` for floordiv with the formula:
1046 // `expr ceildiv divisor is (expr + divisor - 1) floordiv divisor`,
1047 // since uniformly keeping divisons as `floorDiv` helps their
1048 // simplification.
1049 AffineExpr lbExpr = (-(*maybeLbExpr) + divisor - 1).floorDiv(v: divisor);
1050 assert(atIneq64(lbPos, pos) > 0 && "invalid lower bound index");
1051
1052 AffineExpr difference =
1053 simplifyAffineExpr(expr: ubExpr - lbExpr + 1, numDims: 0, numSymbols: getNumSymbolVars());
1054 // If the difference is not constant, ignore the lower bound - upper bound
1055 // pair.
1056 auto constantDiff = dyn_cast<AffineConstantExpr>(Val&: difference);
1057 if (!constantDiff)
1058 continue;
1059
1060 int64_t diffValue = constantDiff.getValue();
1061 // This bound is non-negative by definition.
1062 diffValue = std::max<int64_t>(a: diffValue, b: 0);
1063 if (!minDiff || diffValue < *minDiff) {
1064 minDiff = diffValue;
1065 minLbPosition = lbPos;
1066 minUbPosition = ubPos;
1067 minLbExpr = lbExpr;
1068 minUbExpr = ubExpr;
1069 }
1070 }
1071 }
1072
1073 // Populate outputs where available and needed.
1074 if (lb && minDiff) {
1075 *lb = AffineMap::get(/*dimCount=*/0, symbolCount: getNumSymbolVars(), result: minLbExpr);
1076 }
1077 if (ub)
1078 *ub = AffineMap::get(/*dimCount=*/0, symbolCount: getNumSymbolVars(), result: minUbExpr);
1079 if (minLbPos)
1080 *minLbPos = minLbPosition;
1081 if (minUbPos)
1082 *minUbPos = minUbPosition;
1083
1084 return minDiff;
1085}
1086
1087IntegerSet FlatLinearConstraints::getAsIntegerSet(MLIRContext *context) const {
1088 if (getNumConstraints() == 0)
1089 // Return universal set (always true): 0 == 0.
1090 return IntegerSet::get(dimCount: getNumDimVars(), symbolCount: getNumSymbolVars(),
1091 constraints: getAffineConstantExpr(/*constant=*/0, context),
1092 /*eqFlags=*/true);
1093
1094 // Construct local references.
1095 SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr());
1096
1097 if (failed(Result: computeLocalVars(memo, context))) {
1098 // Check if the local variables without an explicit representation have
1099 // zero coefficients everywhere.
1100 SmallVector<unsigned> noLocalRepVars;
1101 unsigned numDimsSymbols = getNumDimAndSymbolVars();
1102 for (unsigned i = numDimsSymbols, e = getNumVars(); i < e; ++i) {
1103 if (!memo[i] && !isColZero(/*pos=*/i))
1104 noLocalRepVars.push_back(Elt: i - numDimsSymbols);
1105 }
1106 if (!noLocalRepVars.empty()) {
1107 LLVM_DEBUG({
1108 llvm::dbgs() << "local variables at position(s) "
1109 << llvm::interleaved(noLocalRepVars)
1110 << " do not have an explicit representation in:\n";
1111 this->dump();
1112 });
1113 return IntegerSet();
1114 }
1115 }
1116
1117 ArrayRef<AffineExpr> localExprs =
1118 ArrayRef<AffineExpr>(memo).take_back(N: getNumLocalVars());
1119
1120 // Construct the IntegerSet from the equalities/inequalities.
1121 unsigned numDims = getNumDimVars();
1122 unsigned numSyms = getNumSymbolVars();
1123
1124 SmallVector<bool, 16> eqFlags(getNumConstraints());
1125 std::fill(first: eqFlags.begin(), last: eqFlags.begin() + getNumEqualities(), value: true);
1126 std::fill(first: eqFlags.begin() + getNumEqualities(), last: eqFlags.end(), value: false);
1127
1128 SmallVector<AffineExpr, 8> exprs;
1129 exprs.reserve(N: getNumConstraints());
1130
1131 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
1132 exprs.push_back(Elt: getAffineExprFromFlatForm(flatExprs: getEquality64(idx: i), numDims,
1133 numSymbols: numSyms, localExprs, context));
1134 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
1135 exprs.push_back(Elt: getAffineExprFromFlatForm(flatExprs: getInequality64(idx: i), numDims,
1136 numSymbols: numSyms, localExprs, context));
1137 return IntegerSet::get(dimCount: numDims, symbolCount: numSyms, constraints: exprs, eqFlags);
1138}
1139
1140//===----------------------------------------------------------------------===//
1141// FlatLinearValueConstraints
1142//===----------------------------------------------------------------------===//
1143
1144// Construct from an IntegerSet.
1145FlatLinearValueConstraints::FlatLinearValueConstraints(IntegerSet set,
1146 ValueRange operands)
1147 : FlatLinearConstraints(set.getNumInequalities(), set.getNumEqualities(),
1148 set.getNumDims() + set.getNumSymbols() + 1,
1149 set.getNumDims(), set.getNumSymbols(),
1150 /*numLocals=*/0) {
1151 assert((operands.empty() || set.getNumInputs() == operands.size()) &&
1152 "operand count mismatch");
1153 // Set the values for the non-local variables.
1154 for (unsigned i = 0, e = operands.size(); i < e; ++i)
1155 setValue(pos: i, val: operands[i]);
1156
1157 // Flatten expressions and add them to the constraint system.
1158 std::vector<SmallVector<int64_t, 8>> flatExprs;
1159 FlatLinearConstraints localVarCst;
1160 if (failed(Result: getFlattenedAffineExprs(set, flattenedExprs: &flatExprs, localVarCst: &localVarCst))) {
1161 assert(false && "flattening unimplemented for semi-affine integer sets");
1162 return;
1163 }
1164 assert(flatExprs.size() == set.getNumConstraints());
1165 insertVar(kind: VarKind::Local, pos: getNumVarKind(kind: VarKind::Local),
1166 /*num=*/localVarCst.getNumLocalVars());
1167
1168 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
1169 const auto &flatExpr = flatExprs[i];
1170 assert(flatExpr.size() == getNumCols());
1171 if (set.getEqFlags()[i]) {
1172 addEquality(eq: flatExpr);
1173 } else {
1174 addInequality(inEq: flatExpr);
1175 }
1176 }
1177 // Add the other constraints involving local vars from flattening.
1178 append(other: localVarCst);
1179}
1180
1181unsigned FlatLinearValueConstraints::appendDimVar(ValueRange vals) {
1182 unsigned pos = getNumDimVars();
1183 return insertVar(kind: VarKind::SetDim, pos, vals);
1184}
1185
1186unsigned FlatLinearValueConstraints::appendSymbolVar(ValueRange vals) {
1187 unsigned pos = getNumSymbolVars();
1188 return insertVar(kind: VarKind::Symbol, pos, vals);
1189}
1190
1191unsigned FlatLinearValueConstraints::insertDimVar(unsigned pos,
1192 ValueRange vals) {
1193 return insertVar(kind: VarKind::SetDim, pos, vals);
1194}
1195
1196unsigned FlatLinearValueConstraints::insertSymbolVar(unsigned pos,
1197 ValueRange vals) {
1198 return insertVar(kind: VarKind::Symbol, pos, vals);
1199}
1200
1201unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
1202 unsigned num) {
1203 unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
1204
1205 return absolutePos;
1206}
1207
1208unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
1209 ValueRange vals) {
1210 assert(!vals.empty() && "expected ValueRange with Values.");
1211 assert(kind != VarKind::Local &&
1212 "values cannot be attached to local variables.");
1213 unsigned num = vals.size();
1214 unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
1215
1216 // If a Value is provided, insert it; otherwise use std::nullopt.
1217 for (unsigned i = 0, e = vals.size(); i < e; ++i)
1218 if (vals[i])
1219 setValue(pos: absolutePos + i, val: vals[i]);
1220
1221 return absolutePos;
1222}
1223
1224/// Checks if two constraint systems are in the same space, i.e., if they are
1225/// associated with the same set of variables, appearing in the same order.
1226static bool areVarsAligned(const FlatLinearValueConstraints &a,
1227 const FlatLinearValueConstraints &b) {
1228 if (a.getNumDomainVars() != b.getNumDomainVars() ||
1229 a.getNumRangeVars() != b.getNumRangeVars() ||
1230 a.getNumSymbolVars() != b.getNumSymbolVars())
1231 return false;
1232 SmallVector<std::optional<Value>> aMaybeValues = a.getMaybeValues(),
1233 bMaybeValues = b.getMaybeValues();
1234 return std::equal(first1: aMaybeValues.begin(), last1: aMaybeValues.end(),
1235 first2: bMaybeValues.begin(), last2: bMaybeValues.end());
1236}
1237
1238/// Calls areVarsAligned to check if two constraint systems have the same set
1239/// of variables in the same order.
1240bool FlatLinearValueConstraints::areVarsAlignedWithOther(
1241 const FlatLinearConstraints &other) {
1242 return areVarsAligned(a: *this, b: other);
1243}
1244
1245/// Checks if the SSA values associated with `cst`'s variables in range
1246/// [start, end) are unique.
1247static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
1248 const FlatLinearValueConstraints &cst, unsigned start, unsigned end) {
1249
1250 assert(start <= cst.getNumDimAndSymbolVars() &&
1251 "Start position out of bounds");
1252 assert(end <= cst.getNumDimAndSymbolVars() && "End position out of bounds");
1253
1254 if (start >= end)
1255 return true;
1256
1257 SmallPtrSet<Value, 8> uniqueVars;
1258 SmallVector<std::optional<Value>, 8> maybeValuesAll = cst.getMaybeValues();
1259 ArrayRef<std::optional<Value>> maybeValues = {maybeValuesAll.data() + start,
1260 maybeValuesAll.data() + end};
1261
1262 for (std::optional<Value> val : maybeValues)
1263 if (val && !uniqueVars.insert(Ptr: *val).second)
1264 return false;
1265
1266 return true;
1267}
1268
1269/// Checks if the SSA values associated with `cst`'s variables are unique.
1270static bool LLVM_ATTRIBUTE_UNUSED
1271areVarsUnique(const FlatLinearValueConstraints &cst) {
1272 return areVarsUnique(cst, start: 0, end: cst.getNumDimAndSymbolVars());
1273}
1274
1275/// Checks if the SSA values associated with `cst`'s variables of kind `kind`
1276/// are unique.
1277static bool LLVM_ATTRIBUTE_UNUSED
1278areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) {
1279
1280 if (kind == VarKind::SetDim)
1281 return areVarsUnique(cst, start: 0, end: cst.getNumDimVars());
1282 if (kind == VarKind::Symbol)
1283 return areVarsUnique(cst, start: cst.getNumDimVars(),
1284 end: cst.getNumDimAndSymbolVars());
1285 llvm_unreachable("Unexpected VarKind");
1286}
1287
1288/// Merge and align the variables of A and B starting at 'offset', so that
1289/// both constraint systems get the union of the contained variables that is
1290/// dimension-wise and symbol-wise unique; both constraint systems are updated
1291/// so that they have the union of all variables, with A's original
1292/// variables appearing first followed by any of B's variables that didn't
1293/// appear in A. Local variables in B that have the same division
1294/// representation as local variables in A are merged into one. We allow A
1295/// and B to have non-unique values for their variables; in such cases, they are
1296/// still aligned with the variables appearing first aligned with those
1297/// appearing first in the other system from left to right.
1298// E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M])
1299// Output: both A, B have (%i, %j, %k) [%M, %N, %P]
1300static void mergeAndAlignVars(unsigned offset, FlatLinearValueConstraints *a,
1301 FlatLinearValueConstraints *b) {
1302 assert(offset <= a->getNumDimVars() && offset <= b->getNumDimVars());
1303
1304 assert(llvm::all_of(
1305 llvm::drop_begin(a->getMaybeValues(), offset),
1306 [](const std::optional<Value> &var) { return var.has_value(); }));
1307
1308 assert(llvm::all_of(
1309 llvm::drop_begin(b->getMaybeValues(), offset),
1310 [](const std::optional<Value> &var) { return var.has_value(); }));
1311
1312 SmallVector<Value, 4> aDimValues;
1313 a->getValues(start: offset, end: a->getNumDimVars(), values: &aDimValues);
1314
1315 {
1316 // Merge dims from A into B.
1317 unsigned d = offset;
1318 for (Value aDimValue : aDimValues) {
1319 unsigned loc;
1320 // Find from the position `d` since we'd like to also consider the
1321 // possibility of multiple variables with the same `Value`. We align with
1322 // the next appearing one.
1323 if (b->findVar(val: aDimValue, pos: &loc, offset: d)) {
1324 assert(loc >= offset && "A's dim appears in B's aligned range");
1325 assert(loc < b->getNumDimVars() &&
1326 "A's dim appears in B's non-dim position");
1327 b->swapVar(posA: d, posB: loc);
1328 } else {
1329 b->insertDimVar(pos: d, vals: aDimValue);
1330 }
1331 d++;
1332 }
1333 // Dimensions that are in B, but not in A, are added at the end.
1334 for (unsigned t = a->getNumDimVars(), e = b->getNumDimVars(); t < e; t++) {
1335 a->appendDimVar(vals: b->getValue(pos: t));
1336 }
1337 assert(a->getNumDimVars() == b->getNumDimVars() &&
1338 "expected same number of dims");
1339 }
1340
1341 // Merge and align symbols of A and B
1342 a->mergeSymbolVars(other&: *b);
1343 // Merge and align locals of A and B
1344 a->mergeLocalVars(other&: *b);
1345
1346 assert(areVarsAligned(*a, *b) && "IDs expected to be aligned");
1347}
1348
1349// Call 'mergeAndAlignVars' to align constraint systems of 'this' and 'other'.
1350void FlatLinearValueConstraints::mergeAndAlignVarsWithOther(
1351 unsigned offset, FlatLinearValueConstraints *other) {
1352 mergeAndAlignVars(offset, a: this, b: other);
1353}
1354
1355/// Merge and align symbols of `this` and `other` such that both get union of
1356/// of symbols. Existing symbols need not be unique; they will be aligned from
1357/// left to right with duplicates aligned in the same order. Symbols with Value
1358/// as `None` are considered to be inequal to all other symbols.
1359void FlatLinearValueConstraints::mergeSymbolVars(
1360 FlatLinearValueConstraints &other) {
1361
1362 SmallVector<Value, 4> aSymValues;
1363 getValues(start: getNumDimVars(), end: getNumDimAndSymbolVars(), values: &aSymValues);
1364
1365 // Merge symbols: merge symbols into `other` first from `this`.
1366 unsigned s = other.getNumDimVars();
1367 for (Value aSymValue : aSymValues) {
1368 unsigned loc;
1369 // If the var is a symbol in `other`, then align it, otherwise assume that
1370 // it is a new symbol. Search in `other` starting at position `s` since the
1371 // left of it is aligned.
1372 if (other.findVar(val: aSymValue, pos: &loc, offset: s) && loc >= other.getNumDimVars() &&
1373 loc < other.getNumDimAndSymbolVars())
1374 other.swapVar(posA: s, posB: loc);
1375 else
1376 other.insertSymbolVar(pos: s - other.getNumDimVars(), vals: aSymValue);
1377 s++;
1378 }
1379
1380 // Symbols that are in other, but not in this, are added at the end.
1381 for (unsigned t = other.getNumDimVars() + getNumSymbolVars(),
1382 e = other.getNumDimAndSymbolVars();
1383 t < e; t++)
1384 insertSymbolVar(pos: getNumSymbolVars(), vals: other.getValue(pos: t));
1385
1386 assert(getNumSymbolVars() == other.getNumSymbolVars() &&
1387 "expected same number of symbols");
1388}
1389
1390void FlatLinearValueConstraints::removeVarRange(VarKind kind, unsigned varStart,
1391 unsigned varLimit) {
1392 IntegerPolyhedron::removeVarRange(kind, varStart, varLimit);
1393}
1394
1395AffineMap
1396FlatLinearValueConstraints::computeAlignedMap(AffineMap map,
1397 ValueRange operands) const {
1398 assert(map.getNumInputs() == operands.size() && "number of inputs mismatch");
1399
1400 SmallVector<Value> dims, syms;
1401#ifndef NDEBUG
1402 SmallVector<Value> newSyms;
1403 SmallVector<Value> *newSymsPtr = &newSyms;
1404#else
1405 SmallVector<Value> *newSymsPtr = nullptr;
1406#endif // NDEBUG
1407
1408 dims.reserve(N: getNumDimVars());
1409 syms.reserve(N: getNumSymbolVars());
1410 for (unsigned i = 0, e = getNumVarKind(kind: VarKind::SetDim); i < e; ++i) {
1411 Identifier id = space.getId(kind: VarKind::SetDim, pos: i);
1412 dims.push_back(Elt: id.hasValue() ? Value(id.getValue<Value>()) : Value());
1413 }
1414 for (unsigned i = 0, e = getNumVarKind(kind: VarKind::Symbol); i < e; ++i) {
1415 Identifier id = space.getId(kind: VarKind::Symbol, pos: i);
1416 syms.push_back(Elt: id.hasValue() ? Value(id.getValue<Value>()) : Value());
1417 }
1418
1419 AffineMap alignedMap =
1420 alignAffineMapWithValues(map, operands, dims, syms, newSyms: newSymsPtr);
1421 // All symbols are already part of this FlatAffineValueConstraints.
1422 assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
1423 assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
1424 "unexpected new/missing symbols");
1425 return alignedMap;
1426}
1427
1428bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos,
1429 unsigned offset) const {
1430 SmallVector<std::optional<Value>> maybeValues = getMaybeValues();
1431 for (unsigned i = offset, e = maybeValues.size(); i < e; ++i)
1432 if (maybeValues[i] && maybeValues[i].value() == val) {
1433 *pos = i;
1434 return true;
1435 }
1436 return false;
1437}
1438
1439bool FlatLinearValueConstraints::containsVar(Value val) const {
1440 unsigned pos;
1441 return findVar(val, pos: &pos, offset: 0);
1442}
1443
1444void FlatLinearValueConstraints::addBound(BoundType type, Value val,
1445 int64_t value) {
1446 unsigned pos;
1447 if (!findVar(val, pos: &pos))
1448 // This is a pre-condition for this method.
1449 assert(0 && "var not found");
1450 addBound(type, pos, value);
1451}
1452
1453void FlatLinearConstraints::printSpace(raw_ostream &os) const {
1454 IntegerPolyhedron::printSpace(os);
1455 os << "(";
1456 for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++)
1457 os << "None\t";
1458 for (unsigned i = getVarKindOffset(kind: VarKind::Local),
1459 e = getVarKindEnd(kind: VarKind::Local);
1460 i < e; ++i)
1461 os << "Local\t";
1462 os << "const)\n";
1463}
1464
1465void FlatLinearValueConstraints::printSpace(raw_ostream &os) const {
1466 IntegerPolyhedron::printSpace(os);
1467 os << "(";
1468 for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) {
1469 if (hasValue(pos: i))
1470 os << "Value\t";
1471 else
1472 os << "None\t";
1473 }
1474 for (unsigned i = getVarKindOffset(kind: VarKind::Local),
1475 e = getVarKindEnd(kind: VarKind::Local);
1476 i < e; ++i)
1477 os << "Local\t";
1478 os << "const)\n";
1479}
1480
1481void FlatLinearValueConstraints::projectOut(Value val) {
1482 unsigned pos;
1483 bool ret = findVar(val, pos: &pos);
1484 assert(ret);
1485 (void)ret;
1486 fourierMotzkinEliminate(pos);
1487}
1488
1489LogicalResult FlatLinearValueConstraints::unionBoundingBox(
1490 const FlatLinearValueConstraints &otherCst) {
1491 assert(otherCst.getNumDimVars() == getNumDimVars() && "dims mismatch");
1492 SmallVector<std::optional<Value>> maybeValues = getMaybeValues(),
1493 otherMaybeValues =
1494 otherCst.getMaybeValues();
1495 assert(std::equal(maybeValues.begin(), maybeValues.begin() + getNumDimVars(),
1496 otherMaybeValues.begin(),
1497 otherMaybeValues.begin() + getNumDimVars()) &&
1498 "dim values mismatch");
1499 assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here");
1500 assert(getNumLocalVars() == 0 && "local vars not supported yet here");
1501
1502 // Align `other` to this.
1503 if (!areVarsAligned(a: *this, b: otherCst)) {
1504 FlatLinearValueConstraints otherCopy(otherCst);
1505 mergeAndAlignVars(/*offset=*/getNumDimVars(), a: this, b: &otherCopy);
1506 return IntegerPolyhedron::unionBoundingBox(other: otherCopy);
1507 }
1508
1509 return IntegerPolyhedron::unionBoundingBox(other: otherCst);
1510}
1511
1512//===----------------------------------------------------------------------===//
1513// Helper functions
1514//===----------------------------------------------------------------------===//
1515
1516AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
1517 ValueRange dims, ValueRange syms,
1518 SmallVector<Value> *newSyms) {
1519 assert(operands.size() == map.getNumInputs() &&
1520 "expected same number of operands and map inputs");
1521 MLIRContext *ctx = map.getContext();
1522 Builder builder(ctx);
1523 SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {});
1524 unsigned numSymbols = syms.size();
1525 SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {});
1526 if (newSyms) {
1527 newSyms->clear();
1528 newSyms->append(in_start: syms.begin(), in_end: syms.end());
1529 }
1530
1531 for (const auto &operand : llvm::enumerate(First&: operands)) {
1532 // Compute replacement dim/sym of operand.
1533 AffineExpr replacement;
1534 auto dimIt = llvm::find(Range&: dims, Val: operand.value());
1535 auto symIt = llvm::find(Range&: syms, Val: operand.value());
1536 if (dimIt != dims.end()) {
1537 replacement =
1538 builder.getAffineDimExpr(position: std::distance(first: dims.begin(), last: dimIt));
1539 } else if (symIt != syms.end()) {
1540 replacement =
1541 builder.getAffineSymbolExpr(position: std::distance(first: syms.begin(), last: symIt));
1542 } else {
1543 // This operand is neither a dimension nor a symbol. Add it as a new
1544 // symbol.
1545 replacement = builder.getAffineSymbolExpr(position: numSymbols++);
1546 if (newSyms)
1547 newSyms->push_back(Elt: operand.value());
1548 }
1549 // Add to corresponding replacements vector.
1550 if (operand.index() < map.getNumDims()) {
1551 dimReplacements[operand.index()] = replacement;
1552 } else {
1553 symReplacements[operand.index() - map.getNumDims()] = replacement;
1554 }
1555 }
1556
1557 return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
1558 numResultDims: dims.size(), numResultSyms: numSymbols);
1559}
1560
1561LogicalResult
1562mlir::getMultiAffineFunctionFromMap(AffineMap map,
1563 MultiAffineFunction &multiAff) {
1564 FlatLinearConstraints cst;
1565 std::vector<SmallVector<int64_t, 8>> flattenedExprs;
1566 LogicalResult result = getFlattenedAffineExprs(map, flattenedExprs: &flattenedExprs, localVarCst: &cst);
1567
1568 if (result.failed())
1569 return failure();
1570
1571 DivisionRepr divs = cst.getLocalReprs();
1572 assert(divs.hasAllReprs() &&
1573 "AffineMap cannot produce divs without local representation");
1574
1575 // TODO: We shouldn't have to do this conversion.
1576 Matrix<DynamicAPInt> mat(map.getNumResults(),
1577 map.getNumInputs() + divs.getNumDivs() + 1);
1578 for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i)
1579 for (unsigned j = 0, f = flattenedExprs[i].size(); j < f; ++j)
1580 mat(i, j) = flattenedExprs[i][j];
1581
1582 multiAff = MultiAffineFunction(
1583 PresburgerSpace::getRelationSpace(numDomain: map.getNumDims(), numRange: map.getNumResults(),
1584 numSymbols: map.getNumSymbols(), numLocals: divs.getNumDivs()),
1585 mat, divs);
1586
1587 return success();
1588}
1589

source code of mlir/lib/Analysis/FlatLinearValueConstraints.cpp