1 | //===- FlatLinearValueConstraints.cpp - Linear Constraint -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Analysis//FlatLinearValueConstraints.h" |
10 | |
11 | #include "mlir/Analysis/Presburger/LinearTransform.h" |
12 | #include "mlir/Analysis/Presburger/PresburgerSpace.h" |
13 | #include "mlir/Analysis/Presburger/Simplex.h" |
14 | #include "mlir/Analysis/Presburger/Utils.h" |
15 | #include "mlir/IR/AffineExprVisitor.h" |
16 | #include "mlir/IR/Builders.h" |
17 | #include "mlir/IR/IntegerSet.h" |
18 | #include "mlir/Support/LLVM.h" |
19 | #include "mlir/Support/MathExtras.h" |
20 | #include "llvm/ADT/STLExtras.h" |
21 | #include "llvm/ADT/SmallPtrSet.h" |
22 | #include "llvm/ADT/SmallVector.h" |
23 | #include "llvm/Support/Debug.h" |
24 | #include "llvm/Support/raw_ostream.h" |
25 | #include <optional> |
26 | |
27 | #define DEBUG_TYPE "flat-value-constraints" |
28 | |
29 | using namespace mlir; |
30 | using namespace presburger; |
31 | |
32 | //===----------------------------------------------------------------------===// |
33 | // AffineExprFlattener |
34 | //===----------------------------------------------------------------------===// |
35 | |
36 | namespace { |
37 | |
38 | // See comments for SimpleAffineExprFlattener. |
39 | // An AffineExprFlattener extends a SimpleAffineExprFlattener by recording |
40 | // constraint information associated with mod's, floordiv's, and ceildiv's |
41 | // in FlatLinearConstraints 'localVarCst'. |
42 | struct AffineExprFlattener : public SimpleAffineExprFlattener { |
43 | public: |
44 | // Constraints connecting newly introduced local variables (for mod's and |
45 | // div's) to existing (dimensional and symbolic) ones. These are always |
46 | // inequalities. |
47 | IntegerPolyhedron localVarCst; |
48 | |
49 | AffineExprFlattener(unsigned nDims, unsigned nSymbols) |
50 | : SimpleAffineExprFlattener(nDims, nSymbols), |
51 | localVarCst(PresburgerSpace::getSetSpace(numDims: nDims, numSymbols: nSymbols)) {} |
52 | |
53 | private: |
54 | // Add a local variable (needed to flatten a mod, floordiv, ceildiv expr). |
55 | // The local variable added is always a floordiv of a pure add/mul affine |
56 | // function of other variables, coefficients of which are specified in |
57 | // `dividend' and with respect to the positive constant `divisor'. localExpr |
58 | // is the simplified tree expression (AffineExpr) corresponding to the |
59 | // quantifier. |
60 | void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor, |
61 | AffineExpr localExpr) override { |
62 | SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr); |
63 | // Update localVarCst. |
64 | localVarCst.addLocalFloorDiv(dividend, divisor); |
65 | } |
66 | }; |
67 | |
68 | } // namespace |
69 | |
70 | // Flattens the expressions in map. Returns failure if 'expr' was unable to be |
71 | // flattened. For example two specific cases: |
72 | // 1. semi-affine expressions not handled yet. |
73 | // 2. has poison expression (i.e., division by zero). |
74 | static LogicalResult |
75 | getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims, |
76 | unsigned numSymbols, |
77 | std::vector<SmallVector<int64_t, 8>> *flattenedExprs, |
78 | FlatLinearConstraints *localVarCst) { |
79 | if (exprs.empty()) { |
80 | if (localVarCst) |
81 | *localVarCst = FlatLinearConstraints(numDims, numSymbols); |
82 | return success(); |
83 | } |
84 | |
85 | AffineExprFlattener flattener(numDims, numSymbols); |
86 | // Use the same flattener to simplify each expression successively. This way |
87 | // local variables / expressions are shared. |
88 | for (auto expr : exprs) { |
89 | if (!expr.isPureAffine()) |
90 | return failure(); |
91 | // has poison expression |
92 | auto flattenResult = flattener.walkPostOrder(expr); |
93 | if (failed(result: flattenResult)) |
94 | return failure(); |
95 | } |
96 | |
97 | assert(flattener.operandExprStack.size() == exprs.size()); |
98 | flattenedExprs->clear(); |
99 | flattenedExprs->assign(first: flattener.operandExprStack.begin(), |
100 | last: flattener.operandExprStack.end()); |
101 | |
102 | if (localVarCst) |
103 | localVarCst->clearAndCopyFrom(other: flattener.localVarCst); |
104 | |
105 | return success(); |
106 | } |
107 | |
108 | // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to |
109 | // be flattened (semi-affine expressions not handled yet). |
110 | LogicalResult |
111 | mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, |
112 | unsigned numSymbols, |
113 | SmallVectorImpl<int64_t> *flattenedExpr, |
114 | FlatLinearConstraints *localVarCst) { |
115 | std::vector<SmallVector<int64_t, 8>> flattenedExprs; |
116 | LogicalResult ret = ::getFlattenedAffineExprs(exprs: {expr}, numDims, numSymbols, |
117 | flattenedExprs: &flattenedExprs, localVarCst); |
118 | *flattenedExpr = flattenedExprs[0]; |
119 | return ret; |
120 | } |
121 | |
122 | /// Flattens the expressions in map. Returns failure if 'expr' was unable to be |
123 | /// flattened (i.e., semi-affine expressions not handled yet). |
124 | LogicalResult mlir::getFlattenedAffineExprs( |
125 | AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, |
126 | FlatLinearConstraints *localVarCst) { |
127 | if (map.getNumResults() == 0) { |
128 | if (localVarCst) |
129 | *localVarCst = |
130 | FlatLinearConstraints(map.getNumDims(), map.getNumSymbols()); |
131 | return success(); |
132 | } |
133 | return ::getFlattenedAffineExprs(exprs: map.getResults(), numDims: map.getNumDims(), |
134 | numSymbols: map.getNumSymbols(), flattenedExprs, |
135 | localVarCst); |
136 | } |
137 | |
138 | LogicalResult mlir::getFlattenedAffineExprs( |
139 | IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, |
140 | FlatLinearConstraints *localVarCst) { |
141 | if (set.getNumConstraints() == 0) { |
142 | if (localVarCst) |
143 | *localVarCst = |
144 | FlatLinearConstraints(set.getNumDims(), set.getNumSymbols()); |
145 | return success(); |
146 | } |
147 | return ::getFlattenedAffineExprs(exprs: set.getConstraints(), numDims: set.getNumDims(), |
148 | numSymbols: set.getNumSymbols(), flattenedExprs, |
149 | localVarCst); |
150 | } |
151 | |
152 | //===----------------------------------------------------------------------===// |
153 | // FlatLinearConstraints |
154 | //===----------------------------------------------------------------------===// |
155 | |
156 | // Similar to `composeMap` except that no Values need be associated with the |
157 | // constraint system nor are they looked at -- the dimensions and symbols of |
158 | // `other` are expected to correspond 1:1 to `this` system. |
159 | LogicalResult FlatLinearConstraints::composeMatchingMap(AffineMap other) { |
160 | assert(other.getNumDims() == getNumDimVars() && "dim mismatch" ); |
161 | assert(other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch" ); |
162 | |
163 | std::vector<SmallVector<int64_t, 8>> flatExprs; |
164 | if (failed(result: flattenAlignedMapAndMergeLocals(map: other, flattenedExprs: &flatExprs))) |
165 | return failure(); |
166 | assert(flatExprs.size() == other.getNumResults()); |
167 | |
168 | // Add dimensions corresponding to the map's results. |
169 | insertDimVar(/*pos=*/0, /*num=*/other.getNumResults()); |
170 | |
171 | // We add one equality for each result connecting the result dim of the map to |
172 | // the other variables. |
173 | // E.g.: if the expression is 16*i0 + i1, and this is the r^th |
174 | // iteration/result of the value map, we are adding the equality: |
175 | // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we |
176 | // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. |
177 | for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { |
178 | const auto &flatExpr = flatExprs[r]; |
179 | assert(flatExpr.size() >= other.getNumInputs() + 1); |
180 | |
181 | SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0); |
182 | // Set the coefficient for this result to one. |
183 | eqToAdd[r] = 1; |
184 | |
185 | // Dims and symbols. |
186 | for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) { |
187 | // Negate `eq[r]` since the newly added dimension will be set to this one. |
188 | eqToAdd[e + i] = -flatExpr[i]; |
189 | } |
190 | // Local columns of `eq` are at the beginning. |
191 | unsigned j = getNumDimVars() + getNumSymbolVars(); |
192 | unsigned end = flatExpr.size() - 1; |
193 | for (unsigned i = other.getNumInputs(); i < end; i++, j++) { |
194 | eqToAdd[j] = -flatExpr[i]; |
195 | } |
196 | |
197 | // Constant term. |
198 | eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1]; |
199 | |
200 | // Add the equality connecting the result of the map to this constraint set. |
201 | addEquality(eq: eqToAdd); |
202 | } |
203 | |
204 | return success(); |
205 | } |
206 | |
207 | // Determine whether the variable at 'pos' (say var_r) can be expressed as |
208 | // modulo of another known variable (say var_n) w.r.t a constant. For example, |
209 | // if the following constraints hold true: |
210 | // ``` |
211 | // 0 <= var_r <= divisor - 1 |
212 | // var_n - (divisor * q_expr) = var_r |
213 | // ``` |
214 | // where `var_n` is a known variable (called dividend), and `q_expr` is an |
215 | // `AffineExpr` (called the quotient expression), `var_r` can be written as: |
216 | // |
217 | // `var_r = var_n mod divisor`. |
218 | // |
219 | // Additionally, in a special case of the above constaints where `q_expr` is an |
220 | // variable itself that is not yet known (say `var_q`), it can be written as a |
221 | // floordiv in the following way: |
222 | // |
223 | // `var_q = var_n floordiv divisor`. |
224 | // |
225 | // First 'num' dimensional variables starting at 'offset' are |
226 | // derived/to-be-derived in terms of the remaining variables. The remaining |
227 | // variables are assigned trivial affine expressions in `memo`. For example, |
228 | // memo is initilized as follows for a `cst` with 5 dims, when offset=2, num=2: |
229 | // memo ==> d0 d1 . . d2 ... |
230 | // cst ==> c0 c1 c2 c3 c4 ... |
231 | // |
232 | // Returns true if the above mod or floordiv are detected, updating 'memo' with |
233 | // these new expressions. Returns false otherwise. |
234 | static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos, |
235 | unsigned offset, unsigned num, int64_t lbConst, |
236 | int64_t ubConst, MLIRContext *context, |
237 | SmallVectorImpl<AffineExpr> &memo) { |
238 | assert(pos < cst.getNumVars() && "invalid position" ); |
239 | |
240 | // Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can |
241 | // be determined. |
242 | if (lbConst != 0 || ubConst < 1) |
243 | return false; |
244 | int64_t divisor = ubConst + 1; |
245 | |
246 | // Check for the aforementioned conditions in each equality. |
247 | for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities(); |
248 | curEquality < numEqualities; curEquality++) { |
249 | int64_t coefficientAtPos = cst.atEq64(i: curEquality, j: pos); |
250 | // If current equality does not involve `var_r`, continue to the next |
251 | // equality. |
252 | if (coefficientAtPos == 0) |
253 | continue; |
254 | |
255 | // Constant term should be 0 in this equality. |
256 | if (cst.atEq64(i: curEquality, j: cst.getNumCols() - 1) != 0) |
257 | continue; |
258 | |
259 | // Traverse through the equality and construct the dividend expression |
260 | // `dividendExpr`, to contain all the variables which are known and are |
261 | // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the |
262 | // `dividendExpr` gets simplified into a single variable `var_n` discussed |
263 | // above. |
264 | auto dividendExpr = getAffineConstantExpr(constant: 0, context); |
265 | |
266 | // Track the terms that go into quotient expression, later used to detect |
267 | // additional floordiv. |
268 | unsigned quotientCount = 0; |
269 | int quotientPosition = -1; |
270 | int quotientSign = 1; |
271 | |
272 | // Consider each term in the current equality. |
273 | unsigned curVar, e; |
274 | for (curVar = 0, e = cst.getNumDimAndSymbolVars(); curVar < e; ++curVar) { |
275 | // Ignore var_r. |
276 | if (curVar == pos) |
277 | continue; |
278 | int64_t coefficientOfCurVar = cst.atEq64(i: curEquality, j: curVar); |
279 | // Ignore vars that do not contribute to the current equality. |
280 | if (coefficientOfCurVar == 0) |
281 | continue; |
282 | // Check if the current var goes into the quotient expression. |
283 | if (coefficientOfCurVar % (divisor * coefficientAtPos) == 0) { |
284 | quotientCount++; |
285 | quotientPosition = curVar; |
286 | quotientSign = (coefficientOfCurVar * coefficientAtPos) > 0 ? 1 : -1; |
287 | continue; |
288 | } |
289 | // Variables that are part of dividendExpr should be known. |
290 | if (!memo[curVar]) |
291 | break; |
292 | // Append the current variable to the dividend expression. |
293 | dividendExpr = dividendExpr + memo[curVar] * coefficientOfCurVar; |
294 | } |
295 | |
296 | // Can't construct expression as it depends on a yet uncomputed var. |
297 | if (curVar < e) |
298 | continue; |
299 | |
300 | // Express `var_r` in terms of the other vars collected so far. |
301 | if (coefficientAtPos > 0) |
302 | dividendExpr = (-dividendExpr).floorDiv(v: coefficientAtPos); |
303 | else |
304 | dividendExpr = dividendExpr.floorDiv(v: -coefficientAtPos); |
305 | |
306 | // Simplify the expression. |
307 | dividendExpr = simplifyAffineExpr(expr: dividendExpr, numDims: cst.getNumDimVars(), |
308 | numSymbols: cst.getNumSymbolVars()); |
309 | // Only if the final dividend expression is just a single var (which we call |
310 | // `var_n`), we can proceed. |
311 | // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it |
312 | // to dims themselves. |
313 | auto dimExpr = dyn_cast<AffineDimExpr>(Val&: dividendExpr); |
314 | if (!dimExpr) |
315 | continue; |
316 | |
317 | // Express `var_r` as `var_n % divisor` and store the expression in `memo`. |
318 | if (quotientCount >= 1) { |
319 | // Find the column corresponding to `dimExpr`. `num` columns starting at |
320 | // `offset` correspond to previously unknown variables. The column |
321 | // corresponding to the trivially known `dimExpr` can be on either side |
322 | // of these. |
323 | unsigned dimExprPos = dimExpr.getPosition(); |
324 | unsigned dimExprCol = dimExprPos < offset ? dimExprPos : dimExprPos + num; |
325 | auto ub = cst.getConstantBound64(type: BoundType::UB, pos: dimExprCol); |
326 | // If `var_n` has an upperbound that is less than the divisor, mod can be |
327 | // eliminated altogether. |
328 | if (ub && *ub < divisor) |
329 | memo[pos] = dimExpr; |
330 | else |
331 | memo[pos] = dimExpr % divisor; |
332 | // If a unique quotient `var_q` was seen, it can be expressed as |
333 | // `var_n floordiv divisor`. |
334 | if (quotientCount == 1 && !memo[quotientPosition]) |
335 | memo[quotientPosition] = dimExpr.floorDiv(v: divisor) * quotientSign; |
336 | |
337 | return true; |
338 | } |
339 | } |
340 | return false; |
341 | } |
342 | |
343 | /// Check if the pos^th variable can be expressed as a floordiv of an affine |
344 | /// function of other variables (where the divisor is a positive constant) |
345 | /// given the initial set of expressions in `exprs`. If it can be, the |
346 | /// corresponding position in `exprs` is set as the detected affine expr. For |
347 | /// eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. An equality can |
348 | /// also yield a floordiv: eg. 4q = i + j <=> q = (i + j) floordiv 4. 32q + 28 |
349 | /// <= i <= 32q + 31 => q = i floordiv 32. |
350 | static bool detectAsFloorDiv(const FlatLinearConstraints &cst, unsigned pos, |
351 | MLIRContext *context, |
352 | SmallVectorImpl<AffineExpr> &exprs) { |
353 | assert(pos < cst.getNumVars() && "invalid position" ); |
354 | |
355 | // Get upper-lower bound pair for this variable. |
356 | SmallVector<bool, 8> foundRepr(cst.getNumVars(), false); |
357 | for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i) |
358 | if (exprs[i]) |
359 | foundRepr[i] = true; |
360 | |
361 | SmallVector<int64_t, 8> dividend(cst.getNumCols()); |
362 | unsigned divisor; |
363 | auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor); |
364 | |
365 | // No upper-lower bound pair found for this var. |
366 | if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality) |
367 | return false; |
368 | |
369 | // Construct the dividend expression. |
370 | auto dividendExpr = getAffineConstantExpr(constant: dividend.back(), context); |
371 | for (unsigned c = 0, f = cst.getNumVars(); c < f; c++) |
372 | if (dividend[c] != 0) |
373 | dividendExpr = dividendExpr + dividend[c] * exprs[c]; |
374 | |
375 | // Successfully detected the floordiv. |
376 | exprs[pos] = dividendExpr.floorDiv(v: divisor); |
377 | return true; |
378 | } |
379 | |
380 | std::pair<AffineMap, AffineMap> FlatLinearConstraints::getLowerAndUpperBound( |
381 | unsigned pos, unsigned offset, unsigned num, unsigned symStartPos, |
382 | ArrayRef<AffineExpr> localExprs, MLIRContext *context, |
383 | bool closedUB) const { |
384 | assert(pos + offset < getNumDimVars() && "invalid dim start pos" ); |
385 | assert(symStartPos >= (pos + offset) && "invalid sym start pos" ); |
386 | assert(getNumLocalVars() == localExprs.size() && |
387 | "incorrect local exprs count" ); |
388 | |
389 | SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices; |
390 | getLowerAndUpperBoundIndices(pos: pos + offset, lbIndices: &lbIndices, ubIndices: &ubIndices, eqIndices: &eqIndices, |
391 | offset, num); |
392 | |
393 | /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos). |
394 | auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) { |
395 | b.clear(); |
396 | for (unsigned i = 0, e = a.size(); i < e; ++i) { |
397 | if (i < offset || i >= offset + num) |
398 | b.push_back(Elt: a[i]); |
399 | } |
400 | }; |
401 | |
402 | SmallVector<int64_t, 8> lb, ub; |
403 | SmallVector<AffineExpr, 4> lbExprs; |
404 | unsigned dimCount = symStartPos - num; |
405 | unsigned symCount = getNumDimAndSymbolVars() - symStartPos; |
406 | lbExprs.reserve(N: lbIndices.size() + eqIndices.size()); |
407 | // Lower bound expressions. |
408 | for (auto idx : lbIndices) { |
409 | auto ineq = getInequality64(idx); |
410 | // Extract the lower bound (in terms of other coeff's + const), i.e., if |
411 | // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j |
412 | // - 1. |
413 | addCoeffs(ineq, lb); |
414 | std::transform(first: lb.begin(), last: lb.end(), result: lb.begin(), unary_op: std::negate<int64_t>()); |
415 | auto expr = |
416 | getAffineExprFromFlatForm(flatExprs: lb, numDims: dimCount, numSymbols: symCount, localExprs, context); |
417 | // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor |
418 | int64_t divisor = std::abs(i: ineq[pos + offset]); |
419 | expr = (expr + divisor - 1).floorDiv(v: divisor); |
420 | lbExprs.push_back(Elt: expr); |
421 | } |
422 | |
423 | SmallVector<AffineExpr, 4> ubExprs; |
424 | ubExprs.reserve(N: ubIndices.size() + eqIndices.size()); |
425 | // Upper bound expressions. |
426 | for (auto idx : ubIndices) { |
427 | auto ineq = getInequality64(idx); |
428 | // Extract the upper bound (in terms of other coeff's + const). |
429 | addCoeffs(ineq, ub); |
430 | auto expr = |
431 | getAffineExprFromFlatForm(flatExprs: ub, numDims: dimCount, numSymbols: symCount, localExprs, context); |
432 | expr = expr.floorDiv(v: std::abs(i: ineq[pos + offset])); |
433 | int64_t ubAdjustment = closedUB ? 0 : 1; |
434 | ubExprs.push_back(Elt: expr + ubAdjustment); |
435 | } |
436 | |
437 | // Equalities. It's both a lower and a upper bound. |
438 | SmallVector<int64_t, 4> b; |
439 | for (auto idx : eqIndices) { |
440 | auto eq = getEquality64(idx); |
441 | addCoeffs(eq, b); |
442 | if (eq[pos + offset] > 0) |
443 | std::transform(first: b.begin(), last: b.end(), result: b.begin(), unary_op: std::negate<int64_t>()); |
444 | |
445 | // Extract the upper bound (in terms of other coeff's + const). |
446 | auto expr = |
447 | getAffineExprFromFlatForm(flatExprs: b, numDims: dimCount, numSymbols: symCount, localExprs, context); |
448 | expr = expr.floorDiv(v: std::abs(i: eq[pos + offset])); |
449 | // Upper bound is exclusive. |
450 | ubExprs.push_back(Elt: expr + 1); |
451 | // Lower bound. |
452 | expr = |
453 | getAffineExprFromFlatForm(flatExprs: b, numDims: dimCount, numSymbols: symCount, localExprs, context); |
454 | expr = expr.ceilDiv(v: std::abs(i: eq[pos + offset])); |
455 | lbExprs.push_back(Elt: expr); |
456 | } |
457 | |
458 | auto lbMap = AffineMap::get(dimCount, symbolCount: symCount, results: lbExprs, context); |
459 | auto ubMap = AffineMap::get(dimCount, symbolCount: symCount, results: ubExprs, context); |
460 | |
461 | return {lbMap, ubMap}; |
462 | } |
463 | |
/// Computes the lower and upper bounds of the first 'num' dimensional
/// variables (starting at 'offset') as affine maps of the remaining
/// variables (dimensional and symbolic variables). Local variables are
/// themselves explicitly computed as affine functions of other variables in
/// this process if needed.
///
/// The bound maps are written to (*lbMaps)[i] / (*ubMaps)[i] for i in
/// [0, num), so both vectors must already hold at least `num` entries. Upper
/// bounds are closed when `closedUB` is true and exclusive (one past the
/// largest value) otherwise. Note that this system itself is modified:
/// normalizeConstraintsByGCD() is applied to it first.
/// NOTE(review): an entry for which no bound can be derived keeps the value it
/// had on entry; callers presumably pass null maps — confirm.
void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num,
                                           MLIRContext *context,
                                           SmallVectorImpl<AffineMap> *lbMaps,
                                           SmallVectorImpl<AffineMap> *ubMaps,
                                           bool closedUB) {
  assert(offset + num <= getNumDimVars() && "invalid range");

  // Basic simplification.
  normalizeConstraintsByGCD();

  LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
                          << " variables\n");
  LLVM_DEBUG(dump());

  // Record computed/detected variables. memo[v] is the explicit expression of
  // variable v in terms of the map's dims/symbols, or null if still unknown.
  SmallVector<AffineExpr, 8> memo(getNumVars());
  // Initialize dimensional and symbolic variables. Dims outside
  // [offset, offset + num) are renumbered to close the gap left by the `num`
  // variables being solved for.
  for (unsigned i = 0, e = getNumDimVars(); i < e; i++) {
    if (i < offset)
      memo[i] = getAffineDimExpr(i, context);
    else if (i >= offset + num)
      memo[i] = getAffineDimExpr(i - num, context);
  }
  for (unsigned i = getNumDimVars(), e = getNumDimAndSymbolVars(); i < e; i++)
    memo[i] = getAffineSymbolExpr(i - getNumDimVars(), context);

  bool changed;
  do {
    changed = false;
    // Identify yet unknown variables as constants or mod's / floordiv's of
    // other variables if possible.
    for (unsigned pos = 0; pos < getNumVars(); pos++) {
      if (memo[pos])
        continue;

      auto lbConst = getConstantBound64(BoundType::LB, pos);
      auto ubConst = getConstantBound64(BoundType::UB, pos);
      if (lbConst.has_value() && ubConst.has_value()) {
        // Detect equality to a constant.
        if (*lbConst == *ubConst) {
          memo[pos] = getAffineConstantExpr(*lbConst, context);
          changed = true;
          continue;
        }

        // Detect a variable as modulo of another variable w.r.t a
        // constant.
        if (detectAsMod(*this, pos, offset, num, *lbConst, *ubConst, context,
                        memo)) {
          changed = true;
          continue;
        }
      }

      // Detect a variable as a floordiv of an affine function of other
      // variables (divisor is a positive constant).
      if (detectAsFloorDiv(*this, pos, context, memo)) {
        changed = true;
        continue;
      }

      // Detect a variable as an expression of other variables. Only the
      // equality reported by findConstraintWithNonZeroAt is tried here.
      unsigned idx;
      if (!findConstraintWithNonZeroAt(pos, /*isEq=*/true, &idx)) {
        continue;
      }

      // Build AffineExpr solving for variable 'pos' in terms of all others.
      auto expr = getAffineConstantExpr(0, context);
      unsigned j, e;
      for (j = 0, e = getNumVars(); j < e; ++j) {
        if (j == pos)
          continue;
        int64_t c = atEq64(idx, j);
        if (c == 0)
          continue;
        // If any of the involved IDs hasn't been found yet, we can't proceed.
        if (!memo[j])
          break;
        expr = expr + memo[j] * c;
      }
      if (j < e)
        // Can't construct expression as it depends on a yet uncomputed
        // variable.
        continue;

      // Add constant term to AffineExpr.
      expr = expr + atEq64(idx, getNumVars());
      int64_t vPos = atEq64(idx, pos);
      assert(vPos != 0 && "expected non-zero here");
      // Solve vPos * x + expr == 0 for x.
      if (vPos > 0)
        expr = (-expr).floorDiv(vPos);
      else
        // vPos < 0.
        expr = expr.floorDiv(-vPos);
      // Successfully constructed expression.
      memo[pos] = expr;
      changed = true;
    }
    // This loop is guaranteed to reach a fixed point - since once a
    // variable's explicit form is computed (in memo[pos]), it's not updated
    // again.
  } while (changed);

  int64_t ubAdjustment = closedUB ? 0 : 1;

  // Set the lower and upper bound maps for all the variables that were
  // computed as affine expressions of the rest to "detected expr" and
  // "detected expr + ubAdjustment" respectively (i.e. "+ 1" unless `closedUB`).
  // For the undetected ones, fall back to bounds computed from the constraints
  // (only when there are no local variables) and then to constant bounds.
  std::optional<FlatLinearConstraints> tmpClone;
  for (unsigned pos = 0; pos < num; pos++) {
    unsigned numMapDims = getNumDimVars() - num;
    unsigned numMapSymbols = getNumSymbolVars();
    AffineExpr expr = memo[pos + offset];
    if (expr)
      expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);

    AffineMap &lbMap = (*lbMaps)[pos];
    AffineMap &ubMap = (*ubMaps)[pos];

    if (expr) {
      lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
      ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + ubAdjustment);
    } else {
      // TODO: Whenever there are local variables in the dependence
      // constraints, we'll conservatively over-approximate, since we don't
      // always explicitly compute them above (in the while loop).
      if (getNumLocalVars() == 0) {
        // Work on a copy so that we don't update this constraint system.
        if (!tmpClone) {
          tmpClone.emplace(FlatLinearConstraints(*this));
          // Removing redundant inequalities is necessary so that we don't get
          // redundant loop bounds.
          tmpClone->removeRedundantInequalities();
        }
        std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
            pos, offset, num, getNumDimVars(), /*localExprs=*/{}, context,
            closedUB);
      }

      // If the above fails, we'll just use the constant lower bound and the
      // constant upper bound (if they exist) as the slice bounds.
      // TODO: being conservative for the moment in cases that
      // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
      // fixed (b/126426796).
      if (!lbMap || lbMap.getNumResults() > 1) {
        LLVM_DEBUG(llvm::dbgs()
                   << "WARNING: Potentially over-approximating slice lb\n");
        auto lbConst = getConstantBound64(BoundType::LB, pos + offset);
        if (lbConst.has_value()) {
          lbMap = AffineMap::get(numMapDims, numMapSymbols,
                                 getAffineConstantExpr(*lbConst, context));
        }
      }
      if (!ubMap || ubMap.getNumResults() > 1) {
        LLVM_DEBUG(llvm::dbgs()
                   << "WARNING: Potentially over-approximating slice ub\n");
        auto ubConst = getConstantBound64(BoundType::UB, pos + offset);
        if (ubConst.has_value()) {
          ubMap = AffineMap::get(
              numMapDims, numMapSymbols,
              getAffineConstantExpr(*ubConst + ubAdjustment, context));
        }
      }
    }
    LLVM_DEBUG(llvm::dbgs()
               << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
    LLVM_DEBUG(lbMap.dump(););
    LLVM_DEBUG(llvm::dbgs()
               << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
    LLVM_DEBUG(ubMap.dump(););
  }
}
642 | |
643 | LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals( |
644 | AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) { |
645 | FlatLinearConstraints localCst; |
646 | if (failed(result: getFlattenedAffineExprs(map, flattenedExprs, localVarCst: &localCst))) { |
647 | LLVM_DEBUG(llvm::dbgs() |
648 | << "composition unimplemented for semi-affine maps\n" ); |
649 | return failure(); |
650 | } |
651 | |
652 | // Add localCst information. |
653 | if (localCst.getNumLocalVars() > 0) { |
654 | unsigned numLocalVars = getNumLocalVars(); |
655 | // Insert local dims of localCst at the beginning. |
656 | insertLocalVar(/*pos=*/0, /*num=*/localCst.getNumLocalVars()); |
657 | // Insert local dims of `this` at the end of localCst. |
658 | localCst.appendLocalVar(/*num=*/numLocalVars); |
659 | // Dimensions of localCst and this constraint set match. Append localCst to |
660 | // this constraint set. |
661 | append(other: localCst); |
662 | } |
663 | |
664 | return success(); |
665 | } |
666 | |
667 | LogicalResult FlatLinearConstraints::addBound(BoundType type, unsigned pos, |
668 | AffineMap boundMap, |
669 | bool isClosedBound) { |
670 | assert(boundMap.getNumDims() == getNumDimVars() && "dim mismatch" ); |
671 | assert(boundMap.getNumSymbols() == getNumSymbolVars() && "symbol mismatch" ); |
672 | assert(pos < getNumDimAndSymbolVars() && "invalid position" ); |
673 | assert((type != BoundType::EQ || isClosedBound) && |
674 | "EQ bound must be closed." ); |
675 | |
676 | // Equality follows the logic of lower bound except that we add an equality |
677 | // instead of an inequality. |
678 | assert((type != BoundType::EQ || boundMap.getNumResults() == 1) && |
679 | "single result expected" ); |
680 | bool lower = type == BoundType::LB || type == BoundType::EQ; |
681 | |
682 | std::vector<SmallVector<int64_t, 8>> flatExprs; |
683 | if (failed(result: flattenAlignedMapAndMergeLocals(map: boundMap, flattenedExprs: &flatExprs))) |
684 | return failure(); |
685 | assert(flatExprs.size() == boundMap.getNumResults()); |
686 | |
687 | // Add one (in)equality for each result. |
688 | for (const auto &flatExpr : flatExprs) { |
689 | SmallVector<int64_t> ineq(getNumCols(), 0); |
690 | // Dims and symbols. |
691 | for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) { |
692 | ineq[j] = lower ? -flatExpr[j] : flatExpr[j]; |
693 | } |
694 | // Invalid bound: pos appears in `boundMap`. |
695 | // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or |
696 | // its callers to prevent invalid bounds from being added. |
697 | if (ineq[pos] != 0) |
698 | continue; |
699 | ineq[pos] = lower ? 1 : -1; |
700 | // Local columns of `ineq` are at the beginning. |
701 | unsigned j = getNumDimVars() + getNumSymbolVars(); |
702 | unsigned end = flatExpr.size() - 1; |
703 | for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) { |
704 | ineq[j] = lower ? -flatExpr[i] : flatExpr[i]; |
705 | } |
706 | // Make the bound closed in if flatExpr is open. The inequality is always |
707 | // created in the upper bound form, so the adjustment is -1. |
708 | int64_t boundAdjustment = (isClosedBound || type == BoundType::EQ) ? 0 : -1; |
709 | // Constant term. |
710 | ineq[getNumCols() - 1] = (lower ? -flatExpr[flatExpr.size() - 1] |
711 | : flatExpr[flatExpr.size() - 1]) + |
712 | boundAdjustment; |
713 | type == BoundType::EQ ? addEquality(eq: ineq) : addInequality(inEq: ineq); |
714 | } |
715 | |
716 | return success(); |
717 | } |
718 | |
719 | LogicalResult FlatLinearConstraints::addBound(BoundType type, unsigned pos, |
720 | AffineMap boundMap) { |
721 | return addBound(type, pos, boundMap, /*isClosedBound=*/type != BoundType::UB); |
722 | } |
723 | |
724 | /// Compute an explicit representation for local vars. For all systems coming |
725 | /// from MLIR integer sets, maps, or expressions where local vars were |
726 | /// introduced to model floordivs and mods, this always succeeds. |
727 | LogicalResult |
728 | FlatLinearConstraints::computeLocalVars(SmallVectorImpl<AffineExpr> &memo, |
729 | MLIRContext *context) const { |
730 | unsigned numDims = getNumDimVars(); |
731 | unsigned numSyms = getNumSymbolVars(); |
732 | |
733 | // Initialize dimensional and symbolic variables. |
734 | for (unsigned i = 0; i < numDims; i++) |
735 | memo[i] = getAffineDimExpr(position: i, context); |
736 | for (unsigned i = numDims, e = numDims + numSyms; i < e; i++) |
737 | memo[i] = getAffineSymbolExpr(position: i - numDims, context); |
738 | |
739 | bool changed; |
740 | do { |
741 | // Each time `changed` is true at the end of this iteration, one or more |
742 | // local vars would have been detected as floordivs and set in memo; so the |
743 | // number of null entries in memo[...] strictly reduces; so this converges. |
744 | changed = false; |
745 | for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i) |
746 | if (!memo[numDims + numSyms + i] && |
747 | detectAsFloorDiv(cst: *this, /*pos=*/numDims + numSyms + i, context, exprs&: memo)) |
748 | changed = true; |
749 | } while (changed); |
750 | |
751 | ArrayRef<AffineExpr> localExprs = |
752 | ArrayRef<AffineExpr>(memo).take_back(N: getNumLocalVars()); |
753 | return success( |
754 | isSuccess: llvm::all_of(Range&: localExprs, P: [](AffineExpr expr) { return expr; })); |
755 | } |
756 | |
757 | IntegerSet FlatLinearConstraints::getAsIntegerSet(MLIRContext *context) const { |
758 | if (getNumConstraints() == 0) |
759 | // Return universal set (always true): 0 == 0. |
760 | return IntegerSet::get(dimCount: getNumDimVars(), symbolCount: getNumSymbolVars(), |
761 | constraints: getAffineConstantExpr(/*constant=*/0, context), |
762 | /*eqFlags=*/true); |
763 | |
764 | // Construct local references. |
765 | SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr()); |
766 | |
767 | if (failed(result: computeLocalVars(memo, context))) { |
768 | // Check if the local variables without an explicit representation have |
769 | // zero coefficients everywhere. |
770 | SmallVector<unsigned> noLocalRepVars; |
771 | unsigned numDimsSymbols = getNumDimAndSymbolVars(); |
772 | for (unsigned i = numDimsSymbols, e = getNumVars(); i < e; ++i) { |
773 | if (!memo[i] && !isColZero(/*pos=*/i)) |
774 | noLocalRepVars.push_back(Elt: i - numDimsSymbols); |
775 | } |
776 | if (!noLocalRepVars.empty()) { |
777 | LLVM_DEBUG({ |
778 | llvm::dbgs() << "local variables at position(s) " ; |
779 | llvm::interleaveComma(noLocalRepVars, llvm::dbgs()); |
780 | llvm::dbgs() << " do not have an explicit representation in:\n" ; |
781 | this->dump(); |
782 | }); |
783 | return IntegerSet(); |
784 | } |
785 | } |
786 | |
787 | ArrayRef<AffineExpr> localExprs = |
788 | ArrayRef<AffineExpr>(memo).take_back(N: getNumLocalVars()); |
789 | |
790 | // Construct the IntegerSet from the equalities/inequalities. |
791 | unsigned numDims = getNumDimVars(); |
792 | unsigned numSyms = getNumSymbolVars(); |
793 | |
794 | SmallVector<bool, 16> eqFlags(getNumConstraints()); |
795 | std::fill(first: eqFlags.begin(), last: eqFlags.begin() + getNumEqualities(), value: true); |
796 | std::fill(first: eqFlags.begin() + getNumEqualities(), last: eqFlags.end(), value: false); |
797 | |
798 | SmallVector<AffineExpr, 8> exprs; |
799 | exprs.reserve(N: getNumConstraints()); |
800 | |
801 | for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) |
802 | exprs.push_back(Elt: getAffineExprFromFlatForm(flatExprs: getEquality64(idx: i), numDims, |
803 | numSymbols: numSyms, localExprs, context)); |
804 | for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) |
805 | exprs.push_back(Elt: getAffineExprFromFlatForm(flatExprs: getInequality64(idx: i), numDims, |
806 | numSymbols: numSyms, localExprs, context)); |
807 | return IntegerSet::get(dimCount: numDims, symbolCount: numSyms, constraints: exprs, eqFlags); |
808 | } |
809 | |
810 | //===----------------------------------------------------------------------===// |
811 | // FlatLinearValueConstraints |
812 | //===----------------------------------------------------------------------===// |
813 | |
814 | // Construct from an IntegerSet. |
815 | FlatLinearValueConstraints::FlatLinearValueConstraints(IntegerSet set, |
816 | ValueRange operands) |
817 | : FlatLinearConstraints(set.getNumInequalities(), set.getNumEqualities(), |
818 | set.getNumDims() + set.getNumSymbols() + 1, |
819 | set.getNumDims(), set.getNumSymbols(), |
820 | /*numLocals=*/0) { |
821 | assert(operands.empty() || |
822 | set.getNumInputs() == operands.size() && "operand count mismatch" ); |
823 | // Set the values for the non-local variables. |
824 | for (unsigned i = 0, e = operands.size(); i < e; ++i) |
825 | setValue(pos: i, val: operands[i]); |
826 | |
827 | // Flatten expressions and add them to the constraint system. |
828 | std::vector<SmallVector<int64_t, 8>> flatExprs; |
829 | FlatLinearConstraints localVarCst; |
830 | if (failed(result: getFlattenedAffineExprs(set, flattenedExprs: &flatExprs, localVarCst: &localVarCst))) { |
831 | assert(false && "flattening unimplemented for semi-affine integer sets" ); |
832 | return; |
833 | } |
834 | assert(flatExprs.size() == set.getNumConstraints()); |
835 | insertVar(kind: VarKind::Local, pos: getNumVarKind(kind: VarKind::Local), |
836 | /*num=*/localVarCst.getNumLocalVars()); |
837 | |
838 | for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { |
839 | const auto &flatExpr = flatExprs[i]; |
840 | assert(flatExpr.size() == getNumCols()); |
841 | if (set.getEqFlags()[i]) { |
842 | addEquality(eq: flatExpr); |
843 | } else { |
844 | addInequality(inEq: flatExpr); |
845 | } |
846 | } |
847 | // Add the other constraints involving local vars from flattening. |
848 | append(other: localVarCst); |
849 | } |
850 | |
851 | unsigned FlatLinearValueConstraints::appendDimVar(ValueRange vals) { |
852 | unsigned pos = getNumDimVars(); |
853 | return insertVar(kind: VarKind::SetDim, pos, vals); |
854 | } |
855 | |
856 | unsigned FlatLinearValueConstraints::appendSymbolVar(ValueRange vals) { |
857 | unsigned pos = getNumSymbolVars(); |
858 | return insertVar(kind: VarKind::Symbol, pos, vals); |
859 | } |
860 | |
861 | unsigned FlatLinearValueConstraints::insertDimVar(unsigned pos, |
862 | ValueRange vals) { |
863 | return insertVar(kind: VarKind::SetDim, pos, vals); |
864 | } |
865 | |
866 | unsigned FlatLinearValueConstraints::insertSymbolVar(unsigned pos, |
867 | ValueRange vals) { |
868 | return insertVar(kind: VarKind::Symbol, pos, vals); |
869 | } |
870 | |
871 | unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos, |
872 | unsigned num) { |
873 | unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num); |
874 | |
875 | return absolutePos; |
876 | } |
877 | |
878 | unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos, |
879 | ValueRange vals) { |
880 | assert(!vals.empty() && "expected ValueRange with Values." ); |
881 | assert(kind != VarKind::Local && |
882 | "values cannot be attached to local variables." ); |
883 | unsigned num = vals.size(); |
884 | unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num); |
885 | |
886 | // If a Value is provided, insert it; otherwise use std::nullopt. |
887 | for (unsigned i = 0, e = vals.size(); i < e; ++i) |
888 | if (vals[i]) |
889 | setValue(pos: absolutePos + i, val: vals[i]); |
890 | |
891 | return absolutePos; |
892 | } |
893 | |
894 | /// Checks if two constraint systems are in the same space, i.e., if they are |
895 | /// associated with the same set of variables, appearing in the same order. |
896 | static bool areVarsAligned(const FlatLinearValueConstraints &a, |
897 | const FlatLinearValueConstraints &b) { |
898 | if (a.getNumDomainVars() != b.getNumDomainVars() || |
899 | a.getNumRangeVars() != b.getNumRangeVars() || |
900 | a.getNumSymbolVars() != b.getNumSymbolVars()) |
901 | return false; |
902 | SmallVector<std::optional<Value>> aMaybeValues = a.getMaybeValues(), |
903 | bMaybeValues = b.getMaybeValues(); |
904 | return std::equal(first1: aMaybeValues.begin(), last1: aMaybeValues.end(), |
905 | first2: bMaybeValues.begin(), last2: bMaybeValues.end()); |
906 | } |
907 | |
908 | /// Calls areVarsAligned to check if two constraint systems have the same set |
909 | /// of variables in the same order. |
910 | bool FlatLinearValueConstraints::areVarsAlignedWithOther( |
911 | const FlatLinearConstraints &other) { |
912 | return areVarsAligned(a: *this, b: other); |
913 | } |
914 | |
915 | /// Checks if the SSA values associated with `cst`'s variables in range |
916 | /// [start, end) are unique. |
917 | static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique( |
918 | const FlatLinearValueConstraints &cst, unsigned start, unsigned end) { |
919 | |
920 | assert(start <= cst.getNumDimAndSymbolVars() && |
921 | "Start position out of bounds" ); |
922 | assert(end <= cst.getNumDimAndSymbolVars() && "End position out of bounds" ); |
923 | |
924 | if (start >= end) |
925 | return true; |
926 | |
927 | SmallPtrSet<Value, 8> uniqueVars; |
928 | SmallVector<std::optional<Value>, 8> maybeValuesAll = cst.getMaybeValues(); |
929 | ArrayRef<std::optional<Value>> maybeValues = {maybeValuesAll.data() + start, |
930 | maybeValuesAll.data() + end}; |
931 | |
932 | for (std::optional<Value> val : maybeValues) |
933 | if (val && !uniqueVars.insert(Ptr: *val).second) |
934 | return false; |
935 | |
936 | return true; |
937 | } |
938 | |
939 | /// Checks if the SSA values associated with `cst`'s variables are unique. |
940 | static bool LLVM_ATTRIBUTE_UNUSED |
941 | areVarsUnique(const FlatLinearValueConstraints &cst) { |
942 | return areVarsUnique(cst, start: 0, end: cst.getNumDimAndSymbolVars()); |
943 | } |
944 | |
945 | /// Checks if the SSA values associated with `cst`'s variables of kind `kind` |
946 | /// are unique. |
947 | static bool LLVM_ATTRIBUTE_UNUSED |
948 | areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) { |
949 | |
950 | if (kind == VarKind::SetDim) |
951 | return areVarsUnique(cst, start: 0, end: cst.getNumDimVars()); |
952 | if (kind == VarKind::Symbol) |
953 | return areVarsUnique(cst, start: cst.getNumDimVars(), |
954 | end: cst.getNumDimAndSymbolVars()); |
955 | llvm_unreachable("Unexpected VarKind" ); |
956 | } |
957 | |
958 | /// Merge and align the variables of A and B starting at 'offset', so that |
959 | /// both constraint systems get the union of the contained variables that is |
960 | /// dimension-wise and symbol-wise unique; both constraint systems are updated |
961 | /// so that they have the union of all variables, with A's original |
962 | /// variables appearing first followed by any of B's variables that didn't |
963 | /// appear in A. Local variables in B that have the same division |
964 | /// representation as local variables in A are merged into one. We allow A |
965 | /// and B to have non-unique values for their variables; in such cases, they are |
966 | /// still aligned with the variables appearing first aligned with those |
967 | /// appearing first in the other system from left to right. |
968 | // E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M]) |
969 | // Output: both A, B have (%i, %j, %k) [%M, %N, %P] |
970 | static void mergeAndAlignVars(unsigned offset, FlatLinearValueConstraints *a, |
971 | FlatLinearValueConstraints *b) { |
972 | assert(offset <= a->getNumDimVars() && offset <= b->getNumDimVars()); |
973 | |
974 | assert(llvm::all_of( |
975 | llvm::drop_begin(a->getMaybeValues(), offset), |
976 | [](const std::optional<Value> &var) { return var.has_value(); })); |
977 | |
978 | assert(llvm::all_of( |
979 | llvm::drop_begin(b->getMaybeValues(), offset), |
980 | [](const std::optional<Value> &var) { return var.has_value(); })); |
981 | |
982 | SmallVector<Value, 4> aDimValues; |
983 | a->getValues(start: offset, end: a->getNumDimVars(), values: &aDimValues); |
984 | |
985 | { |
986 | // Merge dims from A into B. |
987 | unsigned d = offset; |
988 | for (Value aDimValue : aDimValues) { |
989 | unsigned loc; |
990 | // Find from the position `d` since we'd like to also consider the |
991 | // possibility of multiple variables with the same `Value`. We align with |
992 | // the next appearing one. |
993 | if (b->findVar(val: aDimValue, pos: &loc, offset: d)) { |
994 | assert(loc >= offset && "A's dim appears in B's aligned range" ); |
995 | assert(loc < b->getNumDimVars() && |
996 | "A's dim appears in B's non-dim position" ); |
997 | b->swapVar(posA: d, posB: loc); |
998 | } else { |
999 | b->insertDimVar(pos: d, vals: aDimValue); |
1000 | } |
1001 | d++; |
1002 | } |
1003 | // Dimensions that are in B, but not in A, are added at the end. |
1004 | for (unsigned t = a->getNumDimVars(), e = b->getNumDimVars(); t < e; t++) { |
1005 | a->appendDimVar(vals: b->getValue(pos: t)); |
1006 | } |
1007 | assert(a->getNumDimVars() == b->getNumDimVars() && |
1008 | "expected same number of dims" ); |
1009 | } |
1010 | |
1011 | // Merge and align symbols of A and B |
1012 | a->mergeSymbolVars(other&: *b); |
1013 | // Merge and align locals of A and B |
1014 | a->mergeLocalVars(other&: *b); |
1015 | |
1016 | assert(areVarsAligned(*a, *b) && "IDs expected to be aligned" ); |
1017 | } |
1018 | |
1019 | // Call 'mergeAndAlignVars' to align constraint systems of 'this' and 'other'. |
1020 | void FlatLinearValueConstraints::mergeAndAlignVarsWithOther( |
1021 | unsigned offset, FlatLinearValueConstraints *other) { |
1022 | mergeAndAlignVars(offset, a: this, b: other); |
1023 | } |
1024 | |
1025 | /// Merge and align symbols of `this` and `other` such that both get union of |
1026 | /// of symbols. Existing symbols need not be unique; they will be aligned from |
1027 | /// left to right with duplicates aligned in the same order. Symbols with Value |
1028 | /// as `None` are considered to be inequal to all other symbols. |
1029 | void FlatLinearValueConstraints::mergeSymbolVars( |
1030 | FlatLinearValueConstraints &other) { |
1031 | |
1032 | SmallVector<Value, 4> aSymValues; |
1033 | getValues(start: getNumDimVars(), end: getNumDimAndSymbolVars(), values: &aSymValues); |
1034 | |
1035 | // Merge symbols: merge symbols into `other` first from `this`. |
1036 | unsigned s = other.getNumDimVars(); |
1037 | for (Value aSymValue : aSymValues) { |
1038 | unsigned loc; |
1039 | // If the var is a symbol in `other`, then align it, otherwise assume that |
1040 | // it is a new symbol. Search in `other` starting at position `s` since the |
1041 | // left of it is aligned. |
1042 | if (other.findVar(val: aSymValue, pos: &loc, offset: s) && loc >= other.getNumDimVars() && |
1043 | loc < other.getNumDimAndSymbolVars()) |
1044 | other.swapVar(posA: s, posB: loc); |
1045 | else |
1046 | other.insertSymbolVar(pos: s - other.getNumDimVars(), vals: aSymValue); |
1047 | s++; |
1048 | } |
1049 | |
1050 | // Symbols that are in other, but not in this, are added at the end. |
1051 | for (unsigned t = other.getNumDimVars() + getNumSymbolVars(), |
1052 | e = other.getNumDimAndSymbolVars(); |
1053 | t < e; t++) |
1054 | insertSymbolVar(pos: getNumSymbolVars(), vals: other.getValue(pos: t)); |
1055 | |
1056 | assert(getNumSymbolVars() == other.getNumSymbolVars() && |
1057 | "expected same number of symbols" ); |
1058 | } |
1059 | |
1060 | void FlatLinearValueConstraints::removeVarRange(VarKind kind, unsigned varStart, |
1061 | unsigned varLimit) { |
1062 | IntegerPolyhedron::removeVarRange(kind, varStart, varLimit); |
1063 | } |
1064 | |
1065 | AffineMap |
1066 | FlatLinearValueConstraints::computeAlignedMap(AffineMap map, |
1067 | ValueRange operands) const { |
1068 | assert(map.getNumInputs() == operands.size() && "number of inputs mismatch" ); |
1069 | |
1070 | SmallVector<Value> dims, syms; |
1071 | #ifndef NDEBUG |
1072 | SmallVector<Value> newSyms; |
1073 | SmallVector<Value> *newSymsPtr = &newSyms; |
1074 | #else |
1075 | SmallVector<Value> *newSymsPtr = nullptr; |
1076 | #endif // NDEBUG |
1077 | |
1078 | dims.reserve(N: getNumDimVars()); |
1079 | syms.reserve(N: getNumSymbolVars()); |
1080 | for (unsigned i = 0, e = getNumVarKind(kind: VarKind::SetDim); i < e; ++i) { |
1081 | Identifier id = space.getId(kind: VarKind::SetDim, pos: i); |
1082 | dims.push_back(Elt: id.hasValue() ? Value(id.getValue<Value>()) : Value()); |
1083 | } |
1084 | for (unsigned i = 0, e = getNumVarKind(kind: VarKind::Symbol); i < e; ++i) { |
1085 | Identifier id = space.getId(kind: VarKind::Symbol, pos: i); |
1086 | syms.push_back(Elt: id.hasValue() ? Value(id.getValue<Value>()) : Value()); |
1087 | } |
1088 | |
1089 | AffineMap alignedMap = |
1090 | alignAffineMapWithValues(map, operands, dims, syms, newSyms: newSymsPtr); |
1091 | // All symbols are already part of this FlatAffineValueConstraints. |
1092 | assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols" ); |
1093 | assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) && |
1094 | "unexpected new/missing symbols" ); |
1095 | return alignedMap; |
1096 | } |
1097 | |
1098 | bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos, |
1099 | unsigned offset) const { |
1100 | SmallVector<std::optional<Value>> maybeValues = getMaybeValues(); |
1101 | for (unsigned i = offset, e = maybeValues.size(); i < e; ++i) |
1102 | if (maybeValues[i] && maybeValues[i].value() == val) { |
1103 | *pos = i; |
1104 | return true; |
1105 | } |
1106 | return false; |
1107 | } |
1108 | |
1109 | bool FlatLinearValueConstraints::containsVar(Value val) const { |
1110 | unsigned pos; |
1111 | return findVar(val, pos: &pos, offset: 0); |
1112 | } |
1113 | |
1114 | void FlatLinearValueConstraints::addBound(BoundType type, Value val, |
1115 | int64_t value) { |
1116 | unsigned pos; |
1117 | if (!findVar(val, pos: &pos)) |
1118 | // This is a pre-condition for this method. |
1119 | assert(0 && "var not found" ); |
1120 | addBound(type, pos, value); |
1121 | } |
1122 | |
1123 | void FlatLinearConstraints::printSpace(raw_ostream &os) const { |
1124 | IntegerPolyhedron::printSpace(os); |
1125 | os << "(" ; |
1126 | for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) |
1127 | os << "None\t" ; |
1128 | for (unsigned i = getVarKindOffset(kind: VarKind::Local), |
1129 | e = getVarKindEnd(kind: VarKind::Local); |
1130 | i < e; ++i) |
1131 | os << "Local\t" ; |
1132 | os << "const)\n" ; |
1133 | } |
1134 | |
1135 | void FlatLinearValueConstraints::printSpace(raw_ostream &os) const { |
1136 | IntegerPolyhedron::printSpace(os); |
1137 | os << "(" ; |
1138 | for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) { |
1139 | if (hasValue(pos: i)) |
1140 | os << "Value\t" ; |
1141 | else |
1142 | os << "None\t" ; |
1143 | } |
1144 | for (unsigned i = getVarKindOffset(kind: VarKind::Local), |
1145 | e = getVarKindEnd(kind: VarKind::Local); |
1146 | i < e; ++i) |
1147 | os << "Local\t" ; |
1148 | os << "const)\n" ; |
1149 | } |
1150 | |
1151 | void FlatLinearValueConstraints::projectOut(Value val) { |
1152 | unsigned pos; |
1153 | bool ret = findVar(val, pos: &pos); |
1154 | assert(ret); |
1155 | (void)ret; |
1156 | fourierMotzkinEliminate(pos); |
1157 | } |
1158 | |
1159 | LogicalResult FlatLinearValueConstraints::unionBoundingBox( |
1160 | const FlatLinearValueConstraints &otherCst) { |
1161 | assert(otherCst.getNumDimVars() == getNumDimVars() && "dims mismatch" ); |
1162 | SmallVector<std::optional<Value>> maybeValues = getMaybeValues(), |
1163 | otherMaybeValues = |
1164 | otherCst.getMaybeValues(); |
1165 | assert(std::equal(maybeValues.begin(), maybeValues.begin() + getNumDimVars(), |
1166 | otherMaybeValues.begin(), |
1167 | otherMaybeValues.begin() + getNumDimVars()) && |
1168 | "dim values mismatch" ); |
1169 | assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here" ); |
1170 | assert(getNumLocalVars() == 0 && "local vars not supported yet here" ); |
1171 | |
1172 | // Align `other` to this. |
1173 | if (!areVarsAligned(a: *this, b: otherCst)) { |
1174 | FlatLinearValueConstraints otherCopy(otherCst); |
1175 | mergeAndAlignVars(/*offset=*/getNumDimVars(), a: this, b: &otherCopy); |
1176 | return IntegerPolyhedron::unionBoundingBox(other: otherCopy); |
1177 | } |
1178 | |
1179 | return IntegerPolyhedron::unionBoundingBox(other: otherCst); |
1180 | } |
1181 | |
1182 | //===----------------------------------------------------------------------===// |
1183 | // Helper functions |
1184 | //===----------------------------------------------------------------------===// |
1185 | |
1186 | AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands, |
1187 | ValueRange dims, ValueRange syms, |
1188 | SmallVector<Value> *newSyms) { |
1189 | assert(operands.size() == map.getNumInputs() && |
1190 | "expected same number of operands and map inputs" ); |
1191 | MLIRContext *ctx = map.getContext(); |
1192 | Builder builder(ctx); |
1193 | SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {}); |
1194 | unsigned numSymbols = syms.size(); |
1195 | SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {}); |
1196 | if (newSyms) { |
1197 | newSyms->clear(); |
1198 | newSyms->append(in_start: syms.begin(), in_end: syms.end()); |
1199 | } |
1200 | |
1201 | for (const auto &operand : llvm::enumerate(First&: operands)) { |
1202 | // Compute replacement dim/sym of operand. |
1203 | AffineExpr replacement; |
1204 | auto dimIt = llvm::find(Range&: dims, Val: operand.value()); |
1205 | auto symIt = llvm::find(Range&: syms, Val: operand.value()); |
1206 | if (dimIt != dims.end()) { |
1207 | replacement = |
1208 | builder.getAffineDimExpr(position: std::distance(first: dims.begin(), last: dimIt)); |
1209 | } else if (symIt != syms.end()) { |
1210 | replacement = |
1211 | builder.getAffineSymbolExpr(position: std::distance(first: syms.begin(), last: symIt)); |
1212 | } else { |
1213 | // This operand is neither a dimension nor a symbol. Add it as a new |
1214 | // symbol. |
1215 | replacement = builder.getAffineSymbolExpr(position: numSymbols++); |
1216 | if (newSyms) |
1217 | newSyms->push_back(Elt: operand.value()); |
1218 | } |
1219 | // Add to corresponding replacements vector. |
1220 | if (operand.index() < map.getNumDims()) { |
1221 | dimReplacements[operand.index()] = replacement; |
1222 | } else { |
1223 | symReplacements[operand.index() - map.getNumDims()] = replacement; |
1224 | } |
1225 | } |
1226 | |
1227 | return map.replaceDimsAndSymbols(dimReplacements, symReplacements, |
1228 | numResultDims: dims.size(), numResultSyms: numSymbols); |
1229 | } |
1230 | |
1231 | LogicalResult |
1232 | mlir::getMultiAffineFunctionFromMap(AffineMap map, |
1233 | MultiAffineFunction &multiAff) { |
1234 | FlatLinearConstraints cst; |
1235 | std::vector<SmallVector<int64_t, 8>> flattenedExprs; |
1236 | LogicalResult result = getFlattenedAffineExprs(map, flattenedExprs: &flattenedExprs, localVarCst: &cst); |
1237 | |
1238 | if (result.failed()) |
1239 | return failure(); |
1240 | |
1241 | DivisionRepr divs = cst.getLocalReprs(); |
1242 | assert(divs.hasAllReprs() && |
1243 | "AffineMap cannot produce divs without local representation" ); |
1244 | |
1245 | // TODO: We shouldn't have to do this conversion. |
1246 | Matrix<MPInt> mat(map.getNumResults(), map.getNumInputs() + divs.getNumDivs() + 1); |
1247 | for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i) |
1248 | for (unsigned j = 0, f = flattenedExprs[i].size(); j < f; ++j) |
1249 | mat(i, j) = flattenedExprs[i][j]; |
1250 | |
1251 | multiAff = MultiAffineFunction( |
1252 | PresburgerSpace::getRelationSpace(numDomain: map.getNumDims(), numRange: map.getNumResults(), |
1253 | numSymbols: map.getNumSymbols(), numLocals: divs.getNumDivs()), |
1254 | mat, divs); |
1255 | |
1256 | return success(); |
1257 | } |
1258 | |