1//===- FlatLinearValueConstraints.cpp - Linear Constraint -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis//FlatLinearValueConstraints.h"
10
11#include "mlir/Analysis/Presburger/LinearTransform.h"
12#include "mlir/Analysis/Presburger/PresburgerSpace.h"
13#include "mlir/Analysis/Presburger/Simplex.h"
14#include "mlir/Analysis/Presburger/Utils.h"
15#include "mlir/IR/AffineExprVisitor.h"
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/IntegerSet.h"
18#include "mlir/Support/LLVM.h"
19#include "mlir/Support/MathExtras.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/ADT/SmallVector.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Support/raw_ostream.h"
25#include <optional>
26
27#define DEBUG_TYPE "flat-value-constraints"
28
29using namespace mlir;
30using namespace presburger;
31
32//===----------------------------------------------------------------------===//
33// AffineExprFlattener
34//===----------------------------------------------------------------------===//
35
36namespace {
37
38// See comments for SimpleAffineExprFlattener.
39// An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
40// constraint information associated with mod's, floordiv's, and ceildiv's
41// in FlatLinearConstraints 'localVarCst'.
42struct AffineExprFlattener : public SimpleAffineExprFlattener {
43public:
44 // Constraints connecting newly introduced local variables (for mod's and
45 // div's) to existing (dimensional and symbolic) ones. These are always
46 // inequalities.
47 IntegerPolyhedron localVarCst;
48
49 AffineExprFlattener(unsigned nDims, unsigned nSymbols)
50 : SimpleAffineExprFlattener(nDims, nSymbols),
51 localVarCst(PresburgerSpace::getSetSpace(numDims: nDims, numSymbols: nSymbols)) {}
52
53private:
54 // Add a local variable (needed to flatten a mod, floordiv, ceildiv expr).
55 // The local variable added is always a floordiv of a pure add/mul affine
56 // function of other variables, coefficients of which are specified in
57 // `dividend' and with respect to the positive constant `divisor'. localExpr
58 // is the simplified tree expression (AffineExpr) corresponding to the
59 // quantifier.
60 void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
61 AffineExpr localExpr) override {
62 SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
63 // Update localVarCst.
64 localVarCst.addLocalFloorDiv(dividend, divisor);
65 }
66};
67
68} // namespace
69
70// Flattens the expressions in map. Returns failure if 'expr' was unable to be
71// flattened. For example two specific cases:
72// 1. semi-affine expressions not handled yet.
73// 2. has poison expression (i.e., division by zero).
74static LogicalResult
75getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
76 unsigned numSymbols,
77 std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
78 FlatLinearConstraints *localVarCst) {
79 if (exprs.empty()) {
80 if (localVarCst)
81 *localVarCst = FlatLinearConstraints(numDims, numSymbols);
82 return success();
83 }
84
85 AffineExprFlattener flattener(numDims, numSymbols);
86 // Use the same flattener to simplify each expression successively. This way
87 // local variables / expressions are shared.
88 for (auto expr : exprs) {
89 if (!expr.isPureAffine())
90 return failure();
91 // has poison expression
92 auto flattenResult = flattener.walkPostOrder(expr);
93 if (failed(result: flattenResult))
94 return failure();
95 }
96
97 assert(flattener.operandExprStack.size() == exprs.size());
98 flattenedExprs->clear();
99 flattenedExprs->assign(first: flattener.operandExprStack.begin(),
100 last: flattener.operandExprStack.end());
101
102 if (localVarCst)
103 localVarCst->clearAndCopyFrom(other: flattener.localVarCst);
104
105 return success();
106}
107
108// Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
109// be flattened (semi-affine expressions not handled yet).
110LogicalResult
111mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
112 unsigned numSymbols,
113 SmallVectorImpl<int64_t> *flattenedExpr,
114 FlatLinearConstraints *localVarCst) {
115 std::vector<SmallVector<int64_t, 8>> flattenedExprs;
116 LogicalResult ret = ::getFlattenedAffineExprs(exprs: {expr}, numDims, numSymbols,
117 flattenedExprs: &flattenedExprs, localVarCst);
118 *flattenedExpr = flattenedExprs[0];
119 return ret;
120}
121
122/// Flattens the expressions in map. Returns failure if 'expr' was unable to be
123/// flattened (i.e., semi-affine expressions not handled yet).
124LogicalResult mlir::getFlattenedAffineExprs(
125 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
126 FlatLinearConstraints *localVarCst) {
127 if (map.getNumResults() == 0) {
128 if (localVarCst)
129 *localVarCst =
130 FlatLinearConstraints(map.getNumDims(), map.getNumSymbols());
131 return success();
132 }
133 return ::getFlattenedAffineExprs(exprs: map.getResults(), numDims: map.getNumDims(),
134 numSymbols: map.getNumSymbols(), flattenedExprs,
135 localVarCst);
136}
137
138LogicalResult mlir::getFlattenedAffineExprs(
139 IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
140 FlatLinearConstraints *localVarCst) {
141 if (set.getNumConstraints() == 0) {
142 if (localVarCst)
143 *localVarCst =
144 FlatLinearConstraints(set.getNumDims(), set.getNumSymbols());
145 return success();
146 }
147 return ::getFlattenedAffineExprs(exprs: set.getConstraints(), numDims: set.getNumDims(),
148 numSymbols: set.getNumSymbols(), flattenedExprs,
149 localVarCst);
150}
151
152//===----------------------------------------------------------------------===//
153// FlatLinearConstraints
154//===----------------------------------------------------------------------===//
155
156// Similar to `composeMap` except that no Values need be associated with the
157// constraint system nor are they looked at -- the dimensions and symbols of
158// `other` are expected to correspond 1:1 to `this` system.
159LogicalResult FlatLinearConstraints::composeMatchingMap(AffineMap other) {
160 assert(other.getNumDims() == getNumDimVars() && "dim mismatch");
161 assert(other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
162
163 std::vector<SmallVector<int64_t, 8>> flatExprs;
164 if (failed(result: flattenAlignedMapAndMergeLocals(map: other, flattenedExprs: &flatExprs)))
165 return failure();
166 assert(flatExprs.size() == other.getNumResults());
167
168 // Add dimensions corresponding to the map's results.
169 insertDimVar(/*pos=*/0, /*num=*/other.getNumResults());
170
171 // We add one equality for each result connecting the result dim of the map to
172 // the other variables.
173 // E.g.: if the expression is 16*i0 + i1, and this is the r^th
174 // iteration/result of the value map, we are adding the equality:
175 // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
176 // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
177 for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
178 const auto &flatExpr = flatExprs[r];
179 assert(flatExpr.size() >= other.getNumInputs() + 1);
180
181 SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
182 // Set the coefficient for this result to one.
183 eqToAdd[r] = 1;
184
185 // Dims and symbols.
186 for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
187 // Negate `eq[r]` since the newly added dimension will be set to this one.
188 eqToAdd[e + i] = -flatExpr[i];
189 }
190 // Local columns of `eq` are at the beginning.
191 unsigned j = getNumDimVars() + getNumSymbolVars();
192 unsigned end = flatExpr.size() - 1;
193 for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
194 eqToAdd[j] = -flatExpr[i];
195 }
196
197 // Constant term.
198 eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
199
200 // Add the equality connecting the result of the map to this constraint set.
201 addEquality(eq: eqToAdd);
202 }
203
204 return success();
205}
206
207// Determine whether the variable at 'pos' (say var_r) can be expressed as
208// modulo of another known variable (say var_n) w.r.t a constant. For example,
209// if the following constraints hold true:
210// ```
211// 0 <= var_r <= divisor - 1
212// var_n - (divisor * q_expr) = var_r
213// ```
214// where `var_n` is a known variable (called dividend), and `q_expr` is an
215// `AffineExpr` (called the quotient expression), `var_r` can be written as:
216//
217// `var_r = var_n mod divisor`.
218//
219// Additionally, in a special case of the above constaints where `q_expr` is an
220// variable itself that is not yet known (say `var_q`), it can be written as a
221// floordiv in the following way:
222//
223// `var_q = var_n floordiv divisor`.
224//
225// First 'num' dimensional variables starting at 'offset' are
226// derived/to-be-derived in terms of the remaining variables. The remaining
227// variables are assigned trivial affine expressions in `memo`. For example,
228// memo is initilized as follows for a `cst` with 5 dims, when offset=2, num=2:
229// memo ==> d0 d1 . . d2 ...
230// cst ==> c0 c1 c2 c3 c4 ...
231//
232// Returns true if the above mod or floordiv are detected, updating 'memo' with
233// these new expressions. Returns false otherwise.
234static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos,
235 unsigned offset, unsigned num, int64_t lbConst,
236 int64_t ubConst, MLIRContext *context,
237 SmallVectorImpl<AffineExpr> &memo) {
238 assert(pos < cst.getNumVars() && "invalid position");
239
240 // Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can
241 // be determined.
242 if (lbConst != 0 || ubConst < 1)
243 return false;
244 int64_t divisor = ubConst + 1;
245
246 // Check for the aforementioned conditions in each equality.
247 for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
248 curEquality < numEqualities; curEquality++) {
249 int64_t coefficientAtPos = cst.atEq64(i: curEquality, j: pos);
250 // If current equality does not involve `var_r`, continue to the next
251 // equality.
252 if (coefficientAtPos == 0)
253 continue;
254
255 // Constant term should be 0 in this equality.
256 if (cst.atEq64(i: curEquality, j: cst.getNumCols() - 1) != 0)
257 continue;
258
259 // Traverse through the equality and construct the dividend expression
260 // `dividendExpr`, to contain all the variables which are known and are
261 // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the
262 // `dividendExpr` gets simplified into a single variable `var_n` discussed
263 // above.
264 auto dividendExpr = getAffineConstantExpr(constant: 0, context);
265
266 // Track the terms that go into quotient expression, later used to detect
267 // additional floordiv.
268 unsigned quotientCount = 0;
269 int quotientPosition = -1;
270 int quotientSign = 1;
271
272 // Consider each term in the current equality.
273 unsigned curVar, e;
274 for (curVar = 0, e = cst.getNumDimAndSymbolVars(); curVar < e; ++curVar) {
275 // Ignore var_r.
276 if (curVar == pos)
277 continue;
278 int64_t coefficientOfCurVar = cst.atEq64(i: curEquality, j: curVar);
279 // Ignore vars that do not contribute to the current equality.
280 if (coefficientOfCurVar == 0)
281 continue;
282 // Check if the current var goes into the quotient expression.
283 if (coefficientOfCurVar % (divisor * coefficientAtPos) == 0) {
284 quotientCount++;
285 quotientPosition = curVar;
286 quotientSign = (coefficientOfCurVar * coefficientAtPos) > 0 ? 1 : -1;
287 continue;
288 }
289 // Variables that are part of dividendExpr should be known.
290 if (!memo[curVar])
291 break;
292 // Append the current variable to the dividend expression.
293 dividendExpr = dividendExpr + memo[curVar] * coefficientOfCurVar;
294 }
295
296 // Can't construct expression as it depends on a yet uncomputed var.
297 if (curVar < e)
298 continue;
299
300 // Express `var_r` in terms of the other vars collected so far.
301 if (coefficientAtPos > 0)
302 dividendExpr = (-dividendExpr).floorDiv(v: coefficientAtPos);
303 else
304 dividendExpr = dividendExpr.floorDiv(v: -coefficientAtPos);
305
306 // Simplify the expression.
307 dividendExpr = simplifyAffineExpr(expr: dividendExpr, numDims: cst.getNumDimVars(),
308 numSymbols: cst.getNumSymbolVars());
309 // Only if the final dividend expression is just a single var (which we call
310 // `var_n`), we can proceed.
311 // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
312 // to dims themselves.
313 auto dimExpr = dyn_cast<AffineDimExpr>(Val&: dividendExpr);
314 if (!dimExpr)
315 continue;
316
317 // Express `var_r` as `var_n % divisor` and store the expression in `memo`.
318 if (quotientCount >= 1) {
319 // Find the column corresponding to `dimExpr`. `num` columns starting at
320 // `offset` correspond to previously unknown variables. The column
321 // corresponding to the trivially known `dimExpr` can be on either side
322 // of these.
323 unsigned dimExprPos = dimExpr.getPosition();
324 unsigned dimExprCol = dimExprPos < offset ? dimExprPos : dimExprPos + num;
325 auto ub = cst.getConstantBound64(type: BoundType::UB, pos: dimExprCol);
326 // If `var_n` has an upperbound that is less than the divisor, mod can be
327 // eliminated altogether.
328 if (ub && *ub < divisor)
329 memo[pos] = dimExpr;
330 else
331 memo[pos] = dimExpr % divisor;
332 // If a unique quotient `var_q` was seen, it can be expressed as
333 // `var_n floordiv divisor`.
334 if (quotientCount == 1 && !memo[quotientPosition])
335 memo[quotientPosition] = dimExpr.floorDiv(v: divisor) * quotientSign;
336
337 return true;
338 }
339 }
340 return false;
341}
342
343/// Check if the pos^th variable can be expressed as a floordiv of an affine
344/// function of other variables (where the divisor is a positive constant)
345/// given the initial set of expressions in `exprs`. If it can be, the
346/// corresponding position in `exprs` is set as the detected affine expr. For
347/// eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. An equality can
348/// also yield a floordiv: eg. 4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
349/// <= i <= 32q + 31 => q = i floordiv 32.
350static bool detectAsFloorDiv(const FlatLinearConstraints &cst, unsigned pos,
351 MLIRContext *context,
352 SmallVectorImpl<AffineExpr> &exprs) {
353 assert(pos < cst.getNumVars() && "invalid position");
354
355 // Get upper-lower bound pair for this variable.
356 SmallVector<bool, 8> foundRepr(cst.getNumVars(), false);
357 for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i)
358 if (exprs[i])
359 foundRepr[i] = true;
360
361 SmallVector<int64_t, 8> dividend(cst.getNumCols());
362 unsigned divisor;
363 auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor);
364
365 // No upper-lower bound pair found for this var.
366 if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality)
367 return false;
368
369 // Construct the dividend expression.
370 auto dividendExpr = getAffineConstantExpr(constant: dividend.back(), context);
371 for (unsigned c = 0, f = cst.getNumVars(); c < f; c++)
372 if (dividend[c] != 0)
373 dividendExpr = dividendExpr + dividend[c] * exprs[c];
374
375 // Successfully detected the floordiv.
376 exprs[pos] = dividendExpr.floorDiv(v: divisor);
377 return true;
378}
379
380std::pair<AffineMap, AffineMap> FlatLinearConstraints::getLowerAndUpperBound(
381 unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
382 ArrayRef<AffineExpr> localExprs, MLIRContext *context,
383 bool closedUB) const {
384 assert(pos + offset < getNumDimVars() && "invalid dim start pos");
385 assert(symStartPos >= (pos + offset) && "invalid sym start pos");
386 assert(getNumLocalVars() == localExprs.size() &&
387 "incorrect local exprs count");
388
389 SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
390 getLowerAndUpperBoundIndices(pos: pos + offset, lbIndices: &lbIndices, ubIndices: &ubIndices, eqIndices: &eqIndices,
391 offset, num);
392
393 /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
394 auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
395 b.clear();
396 for (unsigned i = 0, e = a.size(); i < e; ++i) {
397 if (i < offset || i >= offset + num)
398 b.push_back(Elt: a[i]);
399 }
400 };
401
402 SmallVector<int64_t, 8> lb, ub;
403 SmallVector<AffineExpr, 4> lbExprs;
404 unsigned dimCount = symStartPos - num;
405 unsigned symCount = getNumDimAndSymbolVars() - symStartPos;
406 lbExprs.reserve(N: lbIndices.size() + eqIndices.size());
407 // Lower bound expressions.
408 for (auto idx : lbIndices) {
409 auto ineq = getInequality64(idx);
410 // Extract the lower bound (in terms of other coeff's + const), i.e., if
411 // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
412 // - 1.
413 addCoeffs(ineq, lb);
414 std::transform(first: lb.begin(), last: lb.end(), result: lb.begin(), unary_op: std::negate<int64_t>());
415 auto expr =
416 getAffineExprFromFlatForm(flatExprs: lb, numDims: dimCount, numSymbols: symCount, localExprs, context);
417 // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
418 int64_t divisor = std::abs(i: ineq[pos + offset]);
419 expr = (expr + divisor - 1).floorDiv(v: divisor);
420 lbExprs.push_back(Elt: expr);
421 }
422
423 SmallVector<AffineExpr, 4> ubExprs;
424 ubExprs.reserve(N: ubIndices.size() + eqIndices.size());
425 // Upper bound expressions.
426 for (auto idx : ubIndices) {
427 auto ineq = getInequality64(idx);
428 // Extract the upper bound (in terms of other coeff's + const).
429 addCoeffs(ineq, ub);
430 auto expr =
431 getAffineExprFromFlatForm(flatExprs: ub, numDims: dimCount, numSymbols: symCount, localExprs, context);
432 expr = expr.floorDiv(v: std::abs(i: ineq[pos + offset]));
433 int64_t ubAdjustment = closedUB ? 0 : 1;
434 ubExprs.push_back(Elt: expr + ubAdjustment);
435 }
436
437 // Equalities. It's both a lower and a upper bound.
438 SmallVector<int64_t, 4> b;
439 for (auto idx : eqIndices) {
440 auto eq = getEquality64(idx);
441 addCoeffs(eq, b);
442 if (eq[pos + offset] > 0)
443 std::transform(first: b.begin(), last: b.end(), result: b.begin(), unary_op: std::negate<int64_t>());
444
445 // Extract the upper bound (in terms of other coeff's + const).
446 auto expr =
447 getAffineExprFromFlatForm(flatExprs: b, numDims: dimCount, numSymbols: symCount, localExprs, context);
448 expr = expr.floorDiv(v: std::abs(i: eq[pos + offset]));
449 // Upper bound is exclusive.
450 ubExprs.push_back(Elt: expr + 1);
451 // Lower bound.
452 expr =
453 getAffineExprFromFlatForm(flatExprs: b, numDims: dimCount, numSymbols: symCount, localExprs, context);
454 expr = expr.ceilDiv(v: std::abs(i: eq[pos + offset]));
455 lbExprs.push_back(Elt: expr);
456 }
457
458 auto lbMap = AffineMap::get(dimCount, symbolCount: symCount, results: lbExprs, context);
459 auto ubMap = AffineMap::get(dimCount, symbolCount: symCount, results: ubExprs, context);
460
461 return {lbMap, ubMap};
462}
463
464/// Computes the lower and upper bounds of the first 'num' dimensional
465/// variables (starting at 'offset') as affine maps of the remaining
466/// variables (dimensional and symbolic variables). Local variables are
467/// themselves explicitly computed as affine functions of other variables in
468/// this process if needed.
469void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num,
470 MLIRContext *context,
471 SmallVectorImpl<AffineMap> *lbMaps,
472 SmallVectorImpl<AffineMap> *ubMaps,
473 bool closedUB) {
474 assert(offset + num <= getNumDimVars() && "invalid range");
475
476 // Basic simplification.
477 normalizeConstraintsByGCD();
478
479 LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
480 << " variables\n");
481 LLVM_DEBUG(dump());
482
483 // Record computed/detected variables.
484 SmallVector<AffineExpr, 8> memo(getNumVars());
485 // Initialize dimensional and symbolic variables.
486 for (unsigned i = 0, e = getNumDimVars(); i < e; i++) {
487 if (i < offset)
488 memo[i] = getAffineDimExpr(position: i, context);
489 else if (i >= offset + num)
490 memo[i] = getAffineDimExpr(position: i - num, context);
491 }
492 for (unsigned i = getNumDimVars(), e = getNumDimAndSymbolVars(); i < e; i++)
493 memo[i] = getAffineSymbolExpr(position: i - getNumDimVars(), context);
494
495 bool changed;
496 do {
497 changed = false;
498 // Identify yet unknown variables as constants or mod's / floordiv's of
499 // other variables if possible.
500 for (unsigned pos = 0; pos < getNumVars(); pos++) {
501 if (memo[pos])
502 continue;
503
504 auto lbConst = getConstantBound64(type: BoundType::LB, pos);
505 auto ubConst = getConstantBound64(type: BoundType::UB, pos);
506 if (lbConst.has_value() && ubConst.has_value()) {
507 // Detect equality to a constant.
508 if (*lbConst == *ubConst) {
509 memo[pos] = getAffineConstantExpr(constant: *lbConst, context);
510 changed = true;
511 continue;
512 }
513
514 // Detect a variable as modulo of another variable w.r.t a
515 // constant.
516 if (detectAsMod(cst: *this, pos, offset, num, lbConst: *lbConst, ubConst: *ubConst, context,
517 memo)) {
518 changed = true;
519 continue;
520 }
521 }
522
523 // Detect a variable as a floordiv of an affine function of other
524 // variables (divisor is a positive constant).
525 if (detectAsFloorDiv(cst: *this, pos, context, exprs&: memo)) {
526 changed = true;
527 continue;
528 }
529
530 // Detect a variable as an expression of other variables.
531 unsigned idx;
532 if (!findConstraintWithNonZeroAt(colIdx: pos, /*isEq=*/true, rowIdx: &idx)) {
533 continue;
534 }
535
536 // Build AffineExpr solving for variable 'pos' in terms of all others.
537 auto expr = getAffineConstantExpr(constant: 0, context);
538 unsigned j, e;
539 for (j = 0, e = getNumVars(); j < e; ++j) {
540 if (j == pos)
541 continue;
542 int64_t c = atEq64(i: idx, j);
543 if (c == 0)
544 continue;
545 // If any of the involved IDs hasn't been found yet, we can't proceed.
546 if (!memo[j])
547 break;
548 expr = expr + memo[j] * c;
549 }
550 if (j < e)
551 // Can't construct expression as it depends on a yet uncomputed
552 // variable.
553 continue;
554
555 // Add constant term to AffineExpr.
556 expr = expr + atEq64(i: idx, j: getNumVars());
557 int64_t vPos = atEq64(i: idx, j: pos);
558 assert(vPos != 0 && "expected non-zero here");
559 if (vPos > 0)
560 expr = (-expr).floorDiv(v: vPos);
561 else
562 // vPos < 0.
563 expr = expr.floorDiv(v: -vPos);
564 // Successfully constructed expression.
565 memo[pos] = expr;
566 changed = true;
567 }
568 // This loop is guaranteed to reach a fixed point - since once an
569 // variable's explicit form is computed (in memo[pos]), it's not updated
570 // again.
571 } while (changed);
572
573 int64_t ubAdjustment = closedUB ? 0 : 1;
574
575 // Set the lower and upper bound maps for all the variables that were
576 // computed as affine expressions of the rest as the "detected expr" and
577 // "detected expr + 1" respectively; set the undetected ones to null.
578 std::optional<FlatLinearConstraints> tmpClone;
579 for (unsigned pos = 0; pos < num; pos++) {
580 unsigned numMapDims = getNumDimVars() - num;
581 unsigned numMapSymbols = getNumSymbolVars();
582 AffineExpr expr = memo[pos + offset];
583 if (expr)
584 expr = simplifyAffineExpr(expr, numDims: numMapDims, numSymbols: numMapSymbols);
585
586 AffineMap &lbMap = (*lbMaps)[pos];
587 AffineMap &ubMap = (*ubMaps)[pos];
588
589 if (expr) {
590 lbMap = AffineMap::get(dimCount: numMapDims, symbolCount: numMapSymbols, result: expr);
591 ubMap = AffineMap::get(dimCount: numMapDims, symbolCount: numMapSymbols, result: expr + ubAdjustment);
592 } else {
593 // TODO: Whenever there are local variables in the dependence
594 // constraints, we'll conservatively over-approximate, since we don't
595 // always explicitly compute them above (in the while loop).
596 if (getNumLocalVars() == 0) {
597 // Work on a copy so that we don't update this constraint system.
598 if (!tmpClone) {
599 tmpClone.emplace(args: FlatLinearConstraints(*this));
600 // Removing redundant inequalities is necessary so that we don't get
601 // redundant loop bounds.
602 tmpClone->removeRedundantInequalities();
603 }
604 std::tie(args&: lbMap, args&: ubMap) = tmpClone->getLowerAndUpperBound(
605 pos, offset, num, symStartPos: getNumDimVars(), /*localExprs=*/{}, context,
606 closedUB);
607 }
608
609 // If the above fails, we'll just use the constant lower bound and the
610 // constant upper bound (if they exist) as the slice bounds.
611 // TODO: being conservative for the moment in cases that
612 // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
613 // fixed (b/126426796).
614 if (!lbMap || lbMap.getNumResults() > 1) {
615 LLVM_DEBUG(llvm::dbgs()
616 << "WARNING: Potentially over-approximating slice lb\n");
617 auto lbConst = getConstantBound64(type: BoundType::LB, pos: pos + offset);
618 if (lbConst.has_value()) {
619 lbMap = AffineMap::get(dimCount: numMapDims, symbolCount: numMapSymbols,
620 result: getAffineConstantExpr(constant: *lbConst, context));
621 }
622 }
623 if (!ubMap || ubMap.getNumResults() > 1) {
624 LLVM_DEBUG(llvm::dbgs()
625 << "WARNING: Potentially over-approximating slice ub\n");
626 auto ubConst = getConstantBound64(type: BoundType::UB, pos: pos + offset);
627 if (ubConst.has_value()) {
628 ubMap = AffineMap::get(
629 dimCount: numMapDims, symbolCount: numMapSymbols,
630 result: getAffineConstantExpr(constant: *ubConst + ubAdjustment, context));
631 }
632 }
633 }
634 LLVM_DEBUG(llvm::dbgs()
635 << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
636 LLVM_DEBUG(lbMap.dump(););
637 LLVM_DEBUG(llvm::dbgs()
638 << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
639 LLVM_DEBUG(ubMap.dump(););
640 }
641}
642
643LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals(
644 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) {
645 FlatLinearConstraints localCst;
646 if (failed(result: getFlattenedAffineExprs(map, flattenedExprs, localVarCst: &localCst))) {
647 LLVM_DEBUG(llvm::dbgs()
648 << "composition unimplemented for semi-affine maps\n");
649 return failure();
650 }
651
652 // Add localCst information.
653 if (localCst.getNumLocalVars() > 0) {
654 unsigned numLocalVars = getNumLocalVars();
655 // Insert local dims of localCst at the beginning.
656 insertLocalVar(/*pos=*/0, /*num=*/localCst.getNumLocalVars());
657 // Insert local dims of `this` at the end of localCst.
658 localCst.appendLocalVar(/*num=*/numLocalVars);
659 // Dimensions of localCst and this constraint set match. Append localCst to
660 // this constraint set.
661 append(other: localCst);
662 }
663
664 return success();
665}
666
667LogicalResult FlatLinearConstraints::addBound(BoundType type, unsigned pos,
668 AffineMap boundMap,
669 bool isClosedBound) {
670 assert(boundMap.getNumDims() == getNumDimVars() && "dim mismatch");
671 assert(boundMap.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
672 assert(pos < getNumDimAndSymbolVars() && "invalid position");
673 assert((type != BoundType::EQ || isClosedBound) &&
674 "EQ bound must be closed.");
675
676 // Equality follows the logic of lower bound except that we add an equality
677 // instead of an inequality.
678 assert((type != BoundType::EQ || boundMap.getNumResults() == 1) &&
679 "single result expected");
680 bool lower = type == BoundType::LB || type == BoundType::EQ;
681
682 std::vector<SmallVector<int64_t, 8>> flatExprs;
683 if (failed(result: flattenAlignedMapAndMergeLocals(map: boundMap, flattenedExprs: &flatExprs)))
684 return failure();
685 assert(flatExprs.size() == boundMap.getNumResults());
686
687 // Add one (in)equality for each result.
688 for (const auto &flatExpr : flatExprs) {
689 SmallVector<int64_t> ineq(getNumCols(), 0);
690 // Dims and symbols.
691 for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
692 ineq[j] = lower ? -flatExpr[j] : flatExpr[j];
693 }
694 // Invalid bound: pos appears in `boundMap`.
695 // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or
696 // its callers to prevent invalid bounds from being added.
697 if (ineq[pos] != 0)
698 continue;
699 ineq[pos] = lower ? 1 : -1;
700 // Local columns of `ineq` are at the beginning.
701 unsigned j = getNumDimVars() + getNumSymbolVars();
702 unsigned end = flatExpr.size() - 1;
703 for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {
704 ineq[j] = lower ? -flatExpr[i] : flatExpr[i];
705 }
706 // Make the bound closed in if flatExpr is open. The inequality is always
707 // created in the upper bound form, so the adjustment is -1.
708 int64_t boundAdjustment = (isClosedBound || type == BoundType::EQ) ? 0 : -1;
709 // Constant term.
710 ineq[getNumCols() - 1] = (lower ? -flatExpr[flatExpr.size() - 1]
711 : flatExpr[flatExpr.size() - 1]) +
712 boundAdjustment;
713 type == BoundType::EQ ? addEquality(eq: ineq) : addInequality(inEq: ineq);
714 }
715
716 return success();
717}
718
719LogicalResult FlatLinearConstraints::addBound(BoundType type, unsigned pos,
720 AffineMap boundMap) {
721 return addBound(type, pos, boundMap, /*isClosedBound=*/type != BoundType::UB);
722}
723
724/// Compute an explicit representation for local vars. For all systems coming
725/// from MLIR integer sets, maps, or expressions where local vars were
726/// introduced to model floordivs and mods, this always succeeds.
727LogicalResult
728FlatLinearConstraints::computeLocalVars(SmallVectorImpl<AffineExpr> &memo,
729 MLIRContext *context) const {
730 unsigned numDims = getNumDimVars();
731 unsigned numSyms = getNumSymbolVars();
732
733 // Initialize dimensional and symbolic variables.
734 for (unsigned i = 0; i < numDims; i++)
735 memo[i] = getAffineDimExpr(position: i, context);
736 for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
737 memo[i] = getAffineSymbolExpr(position: i - numDims, context);
738
739 bool changed;
740 do {
741 // Each time `changed` is true at the end of this iteration, one or more
742 // local vars would have been detected as floordivs and set in memo; so the
743 // number of null entries in memo[...] strictly reduces; so this converges.
744 changed = false;
745 for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i)
746 if (!memo[numDims + numSyms + i] &&
747 detectAsFloorDiv(cst: *this, /*pos=*/numDims + numSyms + i, context, exprs&: memo))
748 changed = true;
749 } while (changed);
750
751 ArrayRef<AffineExpr> localExprs =
752 ArrayRef<AffineExpr>(memo).take_back(N: getNumLocalVars());
753 return success(
754 isSuccess: llvm::all_of(Range&: localExprs, P: [](AffineExpr expr) { return expr; }));
755}
756
757IntegerSet FlatLinearConstraints::getAsIntegerSet(MLIRContext *context) const {
758 if (getNumConstraints() == 0)
759 // Return universal set (always true): 0 == 0.
760 return IntegerSet::get(dimCount: getNumDimVars(), symbolCount: getNumSymbolVars(),
761 constraints: getAffineConstantExpr(/*constant=*/0, context),
762 /*eqFlags=*/true);
763
764 // Construct local references.
765 SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr());
766
767 if (failed(result: computeLocalVars(memo, context))) {
768 // Check if the local variables without an explicit representation have
769 // zero coefficients everywhere.
770 SmallVector<unsigned> noLocalRepVars;
771 unsigned numDimsSymbols = getNumDimAndSymbolVars();
772 for (unsigned i = numDimsSymbols, e = getNumVars(); i < e; ++i) {
773 if (!memo[i] && !isColZero(/*pos=*/i))
774 noLocalRepVars.push_back(Elt: i - numDimsSymbols);
775 }
776 if (!noLocalRepVars.empty()) {
777 LLVM_DEBUG({
778 llvm::dbgs() << "local variables at position(s) ";
779 llvm::interleaveComma(noLocalRepVars, llvm::dbgs());
780 llvm::dbgs() << " do not have an explicit representation in:\n";
781 this->dump();
782 });
783 return IntegerSet();
784 }
785 }
786
787 ArrayRef<AffineExpr> localExprs =
788 ArrayRef<AffineExpr>(memo).take_back(N: getNumLocalVars());
789
790 // Construct the IntegerSet from the equalities/inequalities.
791 unsigned numDims = getNumDimVars();
792 unsigned numSyms = getNumSymbolVars();
793
794 SmallVector<bool, 16> eqFlags(getNumConstraints());
795 std::fill(first: eqFlags.begin(), last: eqFlags.begin() + getNumEqualities(), value: true);
796 std::fill(first: eqFlags.begin() + getNumEqualities(), last: eqFlags.end(), value: false);
797
798 SmallVector<AffineExpr, 8> exprs;
799 exprs.reserve(N: getNumConstraints());
800
801 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
802 exprs.push_back(Elt: getAffineExprFromFlatForm(flatExprs: getEquality64(idx: i), numDims,
803 numSymbols: numSyms, localExprs, context));
804 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
805 exprs.push_back(Elt: getAffineExprFromFlatForm(flatExprs: getInequality64(idx: i), numDims,
806 numSymbols: numSyms, localExprs, context));
807 return IntegerSet::get(dimCount: numDims, symbolCount: numSyms, constraints: exprs, eqFlags);
808}
809
810//===----------------------------------------------------------------------===//
811// FlatLinearValueConstraints
812//===----------------------------------------------------------------------===//
813
814// Construct from an IntegerSet.
815FlatLinearValueConstraints::FlatLinearValueConstraints(IntegerSet set,
816 ValueRange operands)
817 : FlatLinearConstraints(set.getNumInequalities(), set.getNumEqualities(),
818 set.getNumDims() + set.getNumSymbols() + 1,
819 set.getNumDims(), set.getNumSymbols(),
820 /*numLocals=*/0) {
821 assert(operands.empty() ||
822 set.getNumInputs() == operands.size() && "operand count mismatch");
823 // Set the values for the non-local variables.
824 for (unsigned i = 0, e = operands.size(); i < e; ++i)
825 setValue(pos: i, val: operands[i]);
826
827 // Flatten expressions and add them to the constraint system.
828 std::vector<SmallVector<int64_t, 8>> flatExprs;
829 FlatLinearConstraints localVarCst;
830 if (failed(result: getFlattenedAffineExprs(set, flattenedExprs: &flatExprs, localVarCst: &localVarCst))) {
831 assert(false && "flattening unimplemented for semi-affine integer sets");
832 return;
833 }
834 assert(flatExprs.size() == set.getNumConstraints());
835 insertVar(kind: VarKind::Local, pos: getNumVarKind(kind: VarKind::Local),
836 /*num=*/localVarCst.getNumLocalVars());
837
838 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
839 const auto &flatExpr = flatExprs[i];
840 assert(flatExpr.size() == getNumCols());
841 if (set.getEqFlags()[i]) {
842 addEquality(eq: flatExpr);
843 } else {
844 addInequality(inEq: flatExpr);
845 }
846 }
847 // Add the other constraints involving local vars from flattening.
848 append(other: localVarCst);
849}
850
851unsigned FlatLinearValueConstraints::appendDimVar(ValueRange vals) {
852 unsigned pos = getNumDimVars();
853 return insertVar(kind: VarKind::SetDim, pos, vals);
854}
855
856unsigned FlatLinearValueConstraints::appendSymbolVar(ValueRange vals) {
857 unsigned pos = getNumSymbolVars();
858 return insertVar(kind: VarKind::Symbol, pos, vals);
859}
860
861unsigned FlatLinearValueConstraints::insertDimVar(unsigned pos,
862 ValueRange vals) {
863 return insertVar(kind: VarKind::SetDim, pos, vals);
864}
865
866unsigned FlatLinearValueConstraints::insertSymbolVar(unsigned pos,
867 ValueRange vals) {
868 return insertVar(kind: VarKind::Symbol, pos, vals);
869}
870
871unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
872 unsigned num) {
873 unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
874
875 return absolutePos;
876}
877
878unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
879 ValueRange vals) {
880 assert(!vals.empty() && "expected ValueRange with Values.");
881 assert(kind != VarKind::Local &&
882 "values cannot be attached to local variables.");
883 unsigned num = vals.size();
884 unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
885
886 // If a Value is provided, insert it; otherwise use std::nullopt.
887 for (unsigned i = 0, e = vals.size(); i < e; ++i)
888 if (vals[i])
889 setValue(pos: absolutePos + i, val: vals[i]);
890
891 return absolutePos;
892}
893
894/// Checks if two constraint systems are in the same space, i.e., if they are
895/// associated with the same set of variables, appearing in the same order.
896static bool areVarsAligned(const FlatLinearValueConstraints &a,
897 const FlatLinearValueConstraints &b) {
898 if (a.getNumDomainVars() != b.getNumDomainVars() ||
899 a.getNumRangeVars() != b.getNumRangeVars() ||
900 a.getNumSymbolVars() != b.getNumSymbolVars())
901 return false;
902 SmallVector<std::optional<Value>> aMaybeValues = a.getMaybeValues(),
903 bMaybeValues = b.getMaybeValues();
904 return std::equal(first1: aMaybeValues.begin(), last1: aMaybeValues.end(),
905 first2: bMaybeValues.begin(), last2: bMaybeValues.end());
906}
907
/// Calls areVarsAligned to check if two constraint systems have the same set
/// of variables in the same order.
///
/// NOTE(review): `other` is a FlatLinearConstraints, but areVarsAligned takes a
/// FlatLinearValueConstraints. This call presumably relies on an implicit
/// converting constructor declared in the header, which builds a temporary
/// copy. Confirm that this conversion preserves the Values attached to
/// `other`'s variables.
bool FlatLinearValueConstraints::areVarsAlignedWithOther(
    const FlatLinearConstraints &other) {
  return areVarsAligned(a: *this, b: other);
}
914
915/// Checks if the SSA values associated with `cst`'s variables in range
916/// [start, end) are unique.
917static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
918 const FlatLinearValueConstraints &cst, unsigned start, unsigned end) {
919
920 assert(start <= cst.getNumDimAndSymbolVars() &&
921 "Start position out of bounds");
922 assert(end <= cst.getNumDimAndSymbolVars() && "End position out of bounds");
923
924 if (start >= end)
925 return true;
926
927 SmallPtrSet<Value, 8> uniqueVars;
928 SmallVector<std::optional<Value>, 8> maybeValuesAll = cst.getMaybeValues();
929 ArrayRef<std::optional<Value>> maybeValues = {maybeValuesAll.data() + start,
930 maybeValuesAll.data() + end};
931
932 for (std::optional<Value> val : maybeValues)
933 if (val && !uniqueVars.insert(Ptr: *val).second)
934 return false;
935
936 return true;
937}
938
939/// Checks if the SSA values associated with `cst`'s variables are unique.
940static bool LLVM_ATTRIBUTE_UNUSED
941areVarsUnique(const FlatLinearValueConstraints &cst) {
942 return areVarsUnique(cst, start: 0, end: cst.getNumDimAndSymbolVars());
943}
944
945/// Checks if the SSA values associated with `cst`'s variables of kind `kind`
946/// are unique.
947static bool LLVM_ATTRIBUTE_UNUSED
948areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) {
949
950 if (kind == VarKind::SetDim)
951 return areVarsUnique(cst, start: 0, end: cst.getNumDimVars());
952 if (kind == VarKind::Symbol)
953 return areVarsUnique(cst, start: cst.getNumDimVars(),
954 end: cst.getNumDimAndSymbolVars());
955 llvm_unreachable("Unexpected VarKind");
956}
957
/// Merge and align the variables of A and B starting at 'offset', so that
/// both constraint systems get the union of the contained variables that is
/// dimension-wise and symbol-wise unique; both constraint systems are updated
/// so that they have the union of all variables, with A's original
/// variables appearing first followed by any of B's variables that didn't
/// appear in A. Local variables in B that have the same division
/// representation as local variables in A are merged into one. We allow A
/// and B to have non-unique values for their variables; in such cases, they are
/// still aligned with the variables appearing first aligned with those
/// appearing first in the other system from left to right.
///
/// Preconditions (asserted below): `offset` does not exceed either system's
/// dimension count, and every entry of getMaybeValues() at or after `offset`
/// has a Value, in both systems.
// E.g.: Input: A has ((%i, %j) [%M, %N]) and B has ((%k, %j) [%P, %N, %M])
//       Output: both A, B have (%i, %j, %k) [%M, %N, %P]
static void mergeAndAlignVars(unsigned offset, FlatLinearValueConstraints *a,
                              FlatLinearValueConstraints *b) {
  assert(offset <= a->getNumDimVars() && offset <= b->getNumDimVars());

  assert(llvm::all_of(
      llvm::drop_begin(a->getMaybeValues(), offset),
      [](const std::optional<Value> &var) { return var.has_value(); }));

  assert(llvm::all_of(
      llvm::drop_begin(b->getMaybeValues(), offset),
      [](const std::optional<Value> &var) { return var.has_value(); }));

  // Values of A's dimensions in positions [offset, numDims of A).
  SmallVector<Value, 4> aDimValues;
  a->getValues(start: offset, end: a->getNumDimVars(), values: &aDimValues);

  {
    // Merge dims from A into B.
    // `d` is the position in B that is being aligned with the current A dim.
    // Positions before `d` are already aligned.
    unsigned d = offset;
    for (Value aDimValue : aDimValues) {
      unsigned loc;
      // Find from the position `d` since we'd like to also consider the
      // possibility of multiple variables with the same `Value`. We align with
      // the next appearing one.
      if (b->findVar(val: aDimValue, pos: &loc, offset: d)) {
        assert(loc >= offset && "A's dim appears in B's aligned range");
        assert(loc < b->getNumDimVars() &&
               "A's dim appears in B's non-dim position");
        // Move B's matching dim into slot `d`.
        b->swapVar(posA: d, posB: loc);
      } else {
        // B has no matching dim: create one at slot `d`.
        b->insertDimVar(pos: d, vals: aDimValue);
      }
      d++;
    }
    // Dimensions that are in B, but not in A, are added at the end.
    for (unsigned t = a->getNumDimVars(), e = b->getNumDimVars(); t < e; t++) {
      a->appendDimVar(vals: b->getValue(pos: t));
    }
    assert(a->getNumDimVars() == b->getNumDimVars() &&
           "expected same number of dims");
  }

  // Merge and align symbols of A and B
  a->mergeSymbolVars(other&: *b);
  // Merge and align locals of A and B
  a->mergeLocalVars(other&: *b);

  assert(areVarsAligned(*a, *b) && "IDs expected to be aligned");
}
1018
1019// Call 'mergeAndAlignVars' to align constraint systems of 'this' and 'other'.
1020void FlatLinearValueConstraints::mergeAndAlignVarsWithOther(
1021 unsigned offset, FlatLinearValueConstraints *other) {
1022 mergeAndAlignVars(offset, a: this, b: other);
1023}
1024
/// Merge and align symbols of `this` and `other` such that both get the union
/// of symbols. Existing symbols need not be unique; they will be aligned from
/// left to right with duplicates aligned in the same order. Symbols with Value
/// as `None` are considered to be unequal to all other symbols.
/// NOTE(review): getValues() below is called on every symbol of `this`, and it
/// presumably requires each one to carry a Value. Confirm whether
/// `None` symbols can actually reach this function.
///
/// On return, in both systems the symbols originally in `this` come first,
/// in their original order, followed by the symbols only `other` had.
void FlatLinearValueConstraints::mergeSymbolVars(
    FlatLinearValueConstraints &other) {

  // Values of this system's symbols, in order.
  SmallVector<Value, 4> aSymValues;
  getValues(start: getNumDimVars(), end: getNumDimAndSymbolVars(), values: &aSymValues);

  // Merge symbols: merge symbols into `other` first from `this`.
  // `s` is the absolute position in `other` of the next symbol slot to align.
  unsigned s = other.getNumDimVars();
  for (Value aSymValue : aSymValues) {
    unsigned loc;
    // If the var is a symbol in `other`, then align it, otherwise assume that
    // it is a new symbol. Search in `other` starting at position `s` since the
    // left of it is aligned. insertSymbolVar takes a position relative to the
    // first symbol, hence the subtraction of the dim count.
    if (other.findVar(val: aSymValue, pos: &loc, offset: s) && loc >= other.getNumDimVars() &&
        loc < other.getNumDimAndSymbolVars())
      other.swapVar(posA: s, posB: loc);
    else
      other.insertSymbolVar(pos: s - other.getNumDimVars(), vals: aSymValue);
    s++;
  }

  // Symbols that are in other, but not in this, are added at the end.
  for (unsigned t = other.getNumDimVars() + getNumSymbolVars(),
                e = other.getNumDimAndSymbolVars();
       t < e; t++)
    insertSymbolVar(pos: getNumSymbolVars(), vals: other.getValue(pos: t));

  assert(getNumSymbolVars() == other.getNumSymbolVars() &&
         "expected same number of symbols");
}
1059
1060void FlatLinearValueConstraints::removeVarRange(VarKind kind, unsigned varStart,
1061 unsigned varLimit) {
1062 IntegerPolyhedron::removeVarRange(kind, varStart, varLimit);
1063}
1064
1065AffineMap
1066FlatLinearValueConstraints::computeAlignedMap(AffineMap map,
1067 ValueRange operands) const {
1068 assert(map.getNumInputs() == operands.size() && "number of inputs mismatch");
1069
1070 SmallVector<Value> dims, syms;
1071#ifndef NDEBUG
1072 SmallVector<Value> newSyms;
1073 SmallVector<Value> *newSymsPtr = &newSyms;
1074#else
1075 SmallVector<Value> *newSymsPtr = nullptr;
1076#endif // NDEBUG
1077
1078 dims.reserve(N: getNumDimVars());
1079 syms.reserve(N: getNumSymbolVars());
1080 for (unsigned i = 0, e = getNumVarKind(kind: VarKind::SetDim); i < e; ++i) {
1081 Identifier id = space.getId(kind: VarKind::SetDim, pos: i);
1082 dims.push_back(Elt: id.hasValue() ? Value(id.getValue<Value>()) : Value());
1083 }
1084 for (unsigned i = 0, e = getNumVarKind(kind: VarKind::Symbol); i < e; ++i) {
1085 Identifier id = space.getId(kind: VarKind::Symbol, pos: i);
1086 syms.push_back(Elt: id.hasValue() ? Value(id.getValue<Value>()) : Value());
1087 }
1088
1089 AffineMap alignedMap =
1090 alignAffineMapWithValues(map, operands, dims, syms, newSyms: newSymsPtr);
1091 // All symbols are already part of this FlatAffineValueConstraints.
1092 assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
1093 assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
1094 "unexpected new/missing symbols");
1095 return alignedMap;
1096}
1097
1098bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos,
1099 unsigned offset) const {
1100 SmallVector<std::optional<Value>> maybeValues = getMaybeValues();
1101 for (unsigned i = offset, e = maybeValues.size(); i < e; ++i)
1102 if (maybeValues[i] && maybeValues[i].value() == val) {
1103 *pos = i;
1104 return true;
1105 }
1106 return false;
1107}
1108
1109bool FlatLinearValueConstraints::containsVar(Value val) const {
1110 unsigned pos;
1111 return findVar(val, pos: &pos, offset: 0);
1112}
1113
1114void FlatLinearValueConstraints::addBound(BoundType type, Value val,
1115 int64_t value) {
1116 unsigned pos;
1117 if (!findVar(val, pos: &pos))
1118 // This is a pre-condition for this method.
1119 assert(0 && "var not found");
1120 addBound(type, pos, value);
1121}
1122
1123void FlatLinearConstraints::printSpace(raw_ostream &os) const {
1124 IntegerPolyhedron::printSpace(os);
1125 os << "(";
1126 for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++)
1127 os << "None\t";
1128 for (unsigned i = getVarKindOffset(kind: VarKind::Local),
1129 e = getVarKindEnd(kind: VarKind::Local);
1130 i < e; ++i)
1131 os << "Local\t";
1132 os << "const)\n";
1133}
1134
1135void FlatLinearValueConstraints::printSpace(raw_ostream &os) const {
1136 IntegerPolyhedron::printSpace(os);
1137 os << "(";
1138 for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) {
1139 if (hasValue(pos: i))
1140 os << "Value\t";
1141 else
1142 os << "None\t";
1143 }
1144 for (unsigned i = getVarKindOffset(kind: VarKind::Local),
1145 e = getVarKindEnd(kind: VarKind::Local);
1146 i < e; ++i)
1147 os << "Local\t";
1148 os << "const)\n";
1149}
1150
1151void FlatLinearValueConstraints::projectOut(Value val) {
1152 unsigned pos;
1153 bool ret = findVar(val, pos: &pos);
1154 assert(ret);
1155 (void)ret;
1156 fourierMotzkinEliminate(pos);
1157}
1158
1159LogicalResult FlatLinearValueConstraints::unionBoundingBox(
1160 const FlatLinearValueConstraints &otherCst) {
1161 assert(otherCst.getNumDimVars() == getNumDimVars() && "dims mismatch");
1162 SmallVector<std::optional<Value>> maybeValues = getMaybeValues(),
1163 otherMaybeValues =
1164 otherCst.getMaybeValues();
1165 assert(std::equal(maybeValues.begin(), maybeValues.begin() + getNumDimVars(),
1166 otherMaybeValues.begin(),
1167 otherMaybeValues.begin() + getNumDimVars()) &&
1168 "dim values mismatch");
1169 assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here");
1170 assert(getNumLocalVars() == 0 && "local vars not supported yet here");
1171
1172 // Align `other` to this.
1173 if (!areVarsAligned(a: *this, b: otherCst)) {
1174 FlatLinearValueConstraints otherCopy(otherCst);
1175 mergeAndAlignVars(/*offset=*/getNumDimVars(), a: this, b: &otherCopy);
1176 return IntegerPolyhedron::unionBoundingBox(other: otherCopy);
1177 }
1178
1179 return IntegerPolyhedron::unionBoundingBox(other: otherCst);
1180}
1181
1182//===----------------------------------------------------------------------===//
1183// Helper functions
1184//===----------------------------------------------------------------------===//
1185
1186AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
1187 ValueRange dims, ValueRange syms,
1188 SmallVector<Value> *newSyms) {
1189 assert(operands.size() == map.getNumInputs() &&
1190 "expected same number of operands and map inputs");
1191 MLIRContext *ctx = map.getContext();
1192 Builder builder(ctx);
1193 SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {});
1194 unsigned numSymbols = syms.size();
1195 SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {});
1196 if (newSyms) {
1197 newSyms->clear();
1198 newSyms->append(in_start: syms.begin(), in_end: syms.end());
1199 }
1200
1201 for (const auto &operand : llvm::enumerate(First&: operands)) {
1202 // Compute replacement dim/sym of operand.
1203 AffineExpr replacement;
1204 auto dimIt = llvm::find(Range&: dims, Val: operand.value());
1205 auto symIt = llvm::find(Range&: syms, Val: operand.value());
1206 if (dimIt != dims.end()) {
1207 replacement =
1208 builder.getAffineDimExpr(position: std::distance(first: dims.begin(), last: dimIt));
1209 } else if (symIt != syms.end()) {
1210 replacement =
1211 builder.getAffineSymbolExpr(position: std::distance(first: syms.begin(), last: symIt));
1212 } else {
1213 // This operand is neither a dimension nor a symbol. Add it as a new
1214 // symbol.
1215 replacement = builder.getAffineSymbolExpr(position: numSymbols++);
1216 if (newSyms)
1217 newSyms->push_back(Elt: operand.value());
1218 }
1219 // Add to corresponding replacements vector.
1220 if (operand.index() < map.getNumDims()) {
1221 dimReplacements[operand.index()] = replacement;
1222 } else {
1223 symReplacements[operand.index() - map.getNumDims()] = replacement;
1224 }
1225 }
1226
1227 return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
1228 numResultDims: dims.size(), numResultSyms: numSymbols);
1229}
1230
1231LogicalResult
1232mlir::getMultiAffineFunctionFromMap(AffineMap map,
1233 MultiAffineFunction &multiAff) {
1234 FlatLinearConstraints cst;
1235 std::vector<SmallVector<int64_t, 8>> flattenedExprs;
1236 LogicalResult result = getFlattenedAffineExprs(map, flattenedExprs: &flattenedExprs, localVarCst: &cst);
1237
1238 if (result.failed())
1239 return failure();
1240
1241 DivisionRepr divs = cst.getLocalReprs();
1242 assert(divs.hasAllReprs() &&
1243 "AffineMap cannot produce divs without local representation");
1244
1245 // TODO: We shouldn't have to do this conversion.
1246 Matrix<MPInt> mat(map.getNumResults(), map.getNumInputs() + divs.getNumDivs() + 1);
1247 for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i)
1248 for (unsigned j = 0, f = flattenedExprs[i].size(); j < f; ++j)
1249 mat(i, j) = flattenedExprs[i][j];
1250
1251 multiAff = MultiAffineFunction(
1252 PresburgerSpace::getRelationSpace(numDomain: map.getNumDims(), numRange: map.getNumResults(),
1253 numSymbols: map.getNumSymbols(), numLocals: divs.getNumDivs()),
1254 mat, divs);
1255
1256 return success();
1257}
1258

source code of mlir/lib/Analysis/FlatLinearValueConstraints.cpp