1 | //===- FlatLinearValueConstraints.cpp - Linear Constraint -----------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Analysis//FlatLinearValueConstraints.h" |
10 | |
11 | #include "mlir/Analysis/Presburger/LinearTransform.h" |
12 | #include "mlir/Analysis/Presburger/PresburgerSpace.h" |
13 | #include "mlir/Analysis/Presburger/Simplex.h" |
14 | #include "mlir/Analysis/Presburger/Utils.h" |
15 | #include "mlir/IR/AffineExprVisitor.h" |
16 | #include "mlir/IR/Builders.h" |
17 | #include "mlir/IR/IntegerSet.h" |
18 | #include "mlir/Support/LLVM.h" |
19 | #include "llvm/ADT/STLExtras.h" |
20 | #include "llvm/ADT/SmallPtrSet.h" |
21 | #include "llvm/ADT/SmallVector.h" |
22 | #include "llvm/Support/Debug.h" |
23 | #include "llvm/Support/InterleavedRange.h" |
24 | #include "llvm/Support/raw_ostream.h" |
25 | #include <optional> |
26 | |
27 | #define DEBUG_TYPE "flat-value-constraints" |
28 | |
29 | using namespace mlir; |
30 | using namespace presburger; |
31 | |
32 | //===----------------------------------------------------------------------===// |
33 | // AffineExprFlattener |
34 | //===----------------------------------------------------------------------===// |
35 | |
36 | namespace { |
37 | |
38 | // See comments for SimpleAffineExprFlattener. |
39 | // An AffineExprFlattenerWithLocalVars extends a SimpleAffineExprFlattener by |
40 | // recording constraint information associated with mod's, floordiv's, and |
41 | // ceildiv's in FlatLinearConstraints 'localVarCst'. |
42 | struct AffineExprFlattener : public SimpleAffineExprFlattener { |
43 | using SimpleAffineExprFlattener::SimpleAffineExprFlattener; |
44 | |
45 | // Constraints connecting newly introduced local variables (for mod's and |
46 | // div's) to existing (dimensional and symbolic) ones. These are always |
47 | // inequalities. |
48 | IntegerPolyhedron localVarCst; |
49 | |
50 | AffineExprFlattener(unsigned nDims, unsigned nSymbols) |
51 | : SimpleAffineExprFlattener(nDims, nSymbols), |
52 | localVarCst(PresburgerSpace::getSetSpace(numDims: nDims, numSymbols: nSymbols)) {}; |
53 | |
54 | private: |
55 | // Add a local variable (needed to flatten a mod, floordiv, ceildiv expr). |
56 | // The local variable added is always a floordiv of a pure add/mul affine |
57 | // function of other variables, coefficients of which are specified in |
58 | // `dividend' and with respect to the positive constant `divisor'. localExpr |
59 | // is the simplified tree expression (AffineExpr) corresponding to the |
60 | // quantifier. |
61 | void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor, |
62 | AffineExpr localExpr) override { |
63 | SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr); |
64 | // Update localVarCst. |
65 | localVarCst.addLocalFloorDiv(dividend, divisor); |
66 | } |
67 | |
68 | LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs, |
69 | ArrayRef<int64_t> rhs, |
70 | AffineExpr localExpr) override { |
71 | // AffineExprFlattener does not support semi-affine expressions. |
72 | return failure(); |
73 | } |
74 | }; |
75 | |
76 | // A SemiAffineExprFlattener is an AffineExprFlattenerWithLocalVars that adds |
77 | // conservative bounds for semi-affine expressions (given assumptions hold). If |
78 | // the assumptions required to add the semi-affine bounds are found not to hold |
79 | // the final constraints set will be empty/inconsistent. If the assumptions are |
80 | // never contradicted the final bounds still only will be correct if the |
81 | // assumptions hold. |
82 | struct SemiAffineExprFlattener : public AffineExprFlattener { |
83 | using AffineExprFlattener::AffineExprFlattener; |
84 | |
85 | LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs, |
86 | ArrayRef<int64_t> rhs, |
87 | AffineExpr localExpr) override { |
88 | auto result = |
89 | SimpleAffineExprFlattener::addLocalIdSemiAffine(lhs, rhs, localExpr); |
90 | assert(succeeded(result) && |
91 | "unexpected failure in SimpleAffineExprFlattener"); |
92 | (void)result; |
93 | |
94 | if (localExpr.getKind() == AffineExprKind::Mod) { |
95 | // Given two numbers a and b, division is defined as: |
96 | // |
97 | // a = bq + r |
98 | // 0 <= r < |b| (where |x| is the absolute value of x) |
99 | // |
100 | // q = a floordiv b |
101 | // r = a mod b |
102 | |
103 | // Add a new local variable (r) to represent the mod. |
104 | unsigned rPos = localVarCst.appendVar(kind: VarKind::Local); |
105 | |
106 | // r >= 0 (Can ALWAYS be added) |
107 | localVarCst.addBound(type: BoundType::LB, pos: rPos, value: 0); |
108 | |
109 | // r < b (Can be added if b > 0, which we assume here) |
110 | ArrayRef<int64_t> b = rhs; |
111 | SmallVector<int64_t> bSubR(b); |
112 | bSubR.insert(I: bSubR.begin() + rPos, Elt: -1); |
113 | // Note: bSubR = b - r |
114 | // So this adds the bound b - r >= 1 (equivalent to r < b) |
115 | localVarCst.addBound(type: BoundType::LB, expr: bSubR, value: 1); |
116 | |
117 | // Note: The assumption of b > 0 is based on the affine expression docs, |
118 | // which state "RHS of mod is always a constant or a symbolic expression |
119 | // with a positive value." (see AffineExprKind in AffineExpr.h). If this |
120 | // assumption does not hold constraints (added above) are a contradiction. |
121 | |
122 | return success(); |
123 | } |
124 | |
125 | // TODO: Support other semi-affine expressions. |
126 | return failure(); |
127 | } |
128 | }; |
129 | |
130 | } // namespace |
131 | |
132 | // Flattens the expressions in map. Returns failure if 'expr' was unable to be |
133 | // flattened. For example two specific cases: |
134 | // 1. an unhandled semi-affine expressions is found. |
135 | // 2. has poison expression (i.e., division by zero). |
136 | static LogicalResult |
137 | getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims, |
138 | unsigned numSymbols, |
139 | std::vector<SmallVector<int64_t, 8>> *flattenedExprs, |
140 | FlatLinearConstraints *localVarCst, |
141 | bool addConservativeSemiAffineBounds = false) { |
142 | if (exprs.empty()) { |
143 | if (localVarCst) |
144 | *localVarCst = FlatLinearConstraints(numDims, numSymbols); |
145 | return success(); |
146 | } |
147 | |
148 | auto flattenExprs = [&](AffineExprFlattener &flattener) -> LogicalResult { |
149 | // Use the same flattener to simplify each expression successively. This way |
150 | // local variables / expressions are shared. |
151 | for (auto expr : exprs) { |
152 | auto flattenResult = flattener.walkPostOrder(expr); |
153 | if (failed(Result: flattenResult)) |
154 | return failure(); |
155 | } |
156 | |
157 | assert(flattener.operandExprStack.size() == exprs.size()); |
158 | flattenedExprs->clear(); |
159 | flattenedExprs->assign(first: flattener.operandExprStack.begin(), |
160 | last: flattener.operandExprStack.end()); |
161 | |
162 | if (localVarCst) |
163 | localVarCst->clearAndCopyFrom(other: flattener.localVarCst); |
164 | |
165 | return success(); |
166 | }; |
167 | |
168 | if (addConservativeSemiAffineBounds) { |
169 | SemiAffineExprFlattener flattener(numDims, numSymbols); |
170 | return flattenExprs(flattener); |
171 | } |
172 | |
173 | AffineExprFlattener flattener(numDims, numSymbols); |
174 | return flattenExprs(flattener); |
175 | } |
176 | |
177 | // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to |
178 | // be flattened (an unhandled semi-affine was found). |
179 | LogicalResult mlir::getFlattenedAffineExpr( |
180 | AffineExpr expr, unsigned numDims, unsigned numSymbols, |
181 | SmallVectorImpl<int64_t> *flattenedExpr, FlatLinearConstraints *localVarCst, |
182 | bool addConservativeSemiAffineBounds) { |
183 | std::vector<SmallVector<int64_t, 8>> flattenedExprs; |
184 | LogicalResult ret = |
185 | ::getFlattenedAffineExprs(exprs: {expr}, numDims, numSymbols, flattenedExprs: &flattenedExprs, |
186 | localVarCst, addConservativeSemiAffineBounds); |
187 | *flattenedExpr = flattenedExprs[0]; |
188 | return ret; |
189 | } |
190 | |
191 | /// Flattens the expressions in map. Returns failure if 'expr' was unable to be |
192 | /// flattened (i.e., an unhandled semi-affine was found). |
193 | LogicalResult mlir::getFlattenedAffineExprs( |
194 | AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, |
195 | FlatLinearConstraints *localVarCst, bool addConservativeSemiAffineBounds) { |
196 | if (map.getNumResults() == 0) { |
197 | if (localVarCst) |
198 | *localVarCst = |
199 | FlatLinearConstraints(map.getNumDims(), map.getNumSymbols()); |
200 | return success(); |
201 | } |
202 | return ::getFlattenedAffineExprs( |
203 | exprs: map.getResults(), numDims: map.getNumDims(), numSymbols: map.getNumSymbols(), flattenedExprs, |
204 | localVarCst, addConservativeSemiAffineBounds); |
205 | } |
206 | |
207 | LogicalResult mlir::getFlattenedAffineExprs( |
208 | IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, |
209 | FlatLinearConstraints *localVarCst) { |
210 | if (set.getNumConstraints() == 0) { |
211 | if (localVarCst) |
212 | *localVarCst = |
213 | FlatLinearConstraints(set.getNumDims(), set.getNumSymbols()); |
214 | return success(); |
215 | } |
216 | return ::getFlattenedAffineExprs(exprs: set.getConstraints(), numDims: set.getNumDims(), |
217 | numSymbols: set.getNumSymbols(), flattenedExprs, |
218 | localVarCst); |
219 | } |
220 | |
221 | //===----------------------------------------------------------------------===// |
222 | // FlatLinearConstraints |
223 | //===----------------------------------------------------------------------===// |
224 | |
225 | // Similar to `composeMap` except that no Values need be associated with the |
226 | // constraint system nor are they looked at -- the dimensions and symbols of |
227 | // `other` are expected to correspond 1:1 to `this` system. |
228 | LogicalResult FlatLinearConstraints::composeMatchingMap(AffineMap other) { |
229 | assert(other.getNumDims() == getNumDimVars() && "dim mismatch"); |
230 | assert(other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch"); |
231 | |
232 | std::vector<SmallVector<int64_t, 8>> flatExprs; |
233 | if (failed(Result: flattenAlignedMapAndMergeLocals(map: other, flattenedExprs: &flatExprs))) |
234 | return failure(); |
235 | assert(flatExprs.size() == other.getNumResults()); |
236 | |
237 | // Add dimensions corresponding to the map's results. |
238 | insertDimVar(/*pos=*/0, /*num=*/other.getNumResults()); |
239 | |
240 | // We add one equality for each result connecting the result dim of the map to |
241 | // the other variables. |
242 | // E.g.: if the expression is 16*i0 + i1, and this is the r^th |
243 | // iteration/result of the value map, we are adding the equality: |
244 | // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we |
245 | // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. |
246 | for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { |
247 | const auto &flatExpr = flatExprs[r]; |
248 | assert(flatExpr.size() >= other.getNumInputs() + 1); |
249 | |
250 | SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0); |
251 | // Set the coefficient for this result to one. |
252 | eqToAdd[r] = 1; |
253 | |
254 | // Dims and symbols. |
255 | for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) { |
256 | // Negate `eq[r]` since the newly added dimension will be set to this one. |
257 | eqToAdd[e + i] = -flatExpr[i]; |
258 | } |
259 | // Local columns of `eq` are at the beginning. |
260 | unsigned j = getNumDimVars() + getNumSymbolVars(); |
261 | unsigned end = flatExpr.size() - 1; |
262 | for (unsigned i = other.getNumInputs(); i < end; i++, j++) { |
263 | eqToAdd[j] = -flatExpr[i]; |
264 | } |
265 | |
266 | // Constant term. |
267 | eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1]; |
268 | |
269 | // Add the equality connecting the result of the map to this constraint set. |
270 | addEquality(eq: eqToAdd); |
271 | } |
272 | |
273 | return success(); |
274 | } |
275 | |
276 | // Determine whether the variable at 'pos' (say var_r) can be expressed as |
277 | // modulo of another known variable (say var_n) w.r.t a constant. For example, |
278 | // if the following constraints hold true: |
279 | // ``` |
280 | // 0 <= var_r <= divisor - 1 |
281 | // var_n - (divisor * q_expr) = var_r |
282 | // ``` |
283 | // where `var_n` is a known variable (called dividend), and `q_expr` is an |
284 | // `AffineExpr` (called the quotient expression), `var_r` can be written as: |
285 | // |
286 | // `var_r = var_n mod divisor`. |
287 | // |
288 | // Additionally, in a special case of the above constaints where `q_expr` is an |
289 | // variable itself that is not yet known (say `var_q`), it can be written as a |
290 | // floordiv in the following way: |
291 | // |
292 | // `var_q = var_n floordiv divisor`. |
293 | // |
294 | // First 'num' dimensional variables starting at 'offset' are |
295 | // derived/to-be-derived in terms of the remaining variables. The remaining |
296 | // variables are assigned trivial affine expressions in `memo`. For example, |
297 | // memo is initilized as follows for a `cst` with 5 dims, when offset=2, num=2: |
298 | // memo ==> d0 d1 . . d2 ... |
299 | // cst ==> c0 c1 c2 c3 c4 ... |
300 | // |
301 | // Returns true if the above mod or floordiv are detected, updating 'memo' with |
302 | // these new expressions. Returns false otherwise. |
303 | static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos, |
304 | unsigned offset, unsigned num, int64_t lbConst, |
305 | int64_t ubConst, MLIRContext *context, |
306 | SmallVectorImpl<AffineExpr> &memo) { |
307 | assert(pos < cst.getNumVars() && "invalid position"); |
308 | |
309 | // Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can |
310 | // be determined. |
311 | if (lbConst != 0 || ubConst < 1) |
312 | return false; |
313 | int64_t divisor = ubConst + 1; |
314 | |
315 | // Check for the aforementioned conditions in each equality. |
316 | for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities(); |
317 | curEquality < numEqualities; curEquality++) { |
318 | int64_t coefficientAtPos = cst.atEq64(i: curEquality, j: pos); |
319 | // If current equality does not involve `var_r`, continue to the next |
320 | // equality. |
321 | if (coefficientAtPos == 0) |
322 | continue; |
323 | |
324 | // Constant term should be 0 in this equality. |
325 | if (cst.atEq64(i: curEquality, j: cst.getNumCols() - 1) != 0) |
326 | continue; |
327 | |
328 | // Traverse through the equality and construct the dividend expression |
329 | // `dividendExpr`, to contain all the variables which are known and are |
330 | // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the |
331 | // `dividendExpr` gets simplified into a single variable `var_n` discussed |
332 | // above. |
333 | auto dividendExpr = getAffineConstantExpr(constant: 0, context); |
334 | |
335 | // Track the terms that go into quotient expression, later used to detect |
336 | // additional floordiv. |
337 | unsigned quotientCount = 0; |
338 | int quotientPosition = -1; |
339 | int quotientSign = 1; |
340 | |
341 | // Consider each term in the current equality. |
342 | unsigned curVar, e; |
343 | for (curVar = 0, e = cst.getNumDimAndSymbolVars(); curVar < e; ++curVar) { |
344 | // Ignore var_r. |
345 | if (curVar == pos) |
346 | continue; |
347 | int64_t coefficientOfCurVar = cst.atEq64(i: curEquality, j: curVar); |
348 | // Ignore vars that do not contribute to the current equality. |
349 | if (coefficientOfCurVar == 0) |
350 | continue; |
351 | // Check if the current var goes into the quotient expression. |
352 | if (coefficientOfCurVar % (divisor * coefficientAtPos) == 0) { |
353 | quotientCount++; |
354 | quotientPosition = curVar; |
355 | quotientSign = (coefficientOfCurVar * coefficientAtPos) > 0 ? 1 : -1; |
356 | continue; |
357 | } |
358 | // Variables that are part of dividendExpr should be known. |
359 | if (!memo[curVar]) |
360 | break; |
361 | // Append the current variable to the dividend expression. |
362 | dividendExpr = dividendExpr + memo[curVar] * coefficientOfCurVar; |
363 | } |
364 | |
365 | // Can't construct expression as it depends on a yet uncomputed var. |
366 | if (curVar < e) |
367 | continue; |
368 | |
369 | // Express `var_r` in terms of the other vars collected so far. |
370 | if (coefficientAtPos > 0) |
371 | dividendExpr = (-dividendExpr).floorDiv(v: coefficientAtPos); |
372 | else |
373 | dividendExpr = dividendExpr.floorDiv(v: -coefficientAtPos); |
374 | |
375 | // Simplify the expression. |
376 | dividendExpr = simplifyAffineExpr(expr: dividendExpr, numDims: cst.getNumDimVars(), |
377 | numSymbols: cst.getNumSymbolVars()); |
378 | // Only if the final dividend expression is just a single var (which we call |
379 | // `var_n`), we can proceed. |
380 | // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it |
381 | // to dims themselves. |
382 | auto dimExpr = dyn_cast<AffineDimExpr>(Val&: dividendExpr); |
383 | if (!dimExpr) |
384 | continue; |
385 | |
386 | // Express `var_r` as `var_n % divisor` and store the expression in `memo`. |
387 | if (quotientCount >= 1) { |
388 | // Find the column corresponding to `dimExpr`. `num` columns starting at |
389 | // `offset` correspond to previously unknown variables. The column |
390 | // corresponding to the trivially known `dimExpr` can be on either side |
391 | // of these. |
392 | unsigned dimExprPos = dimExpr.getPosition(); |
393 | unsigned dimExprCol = dimExprPos < offset ? dimExprPos : dimExprPos + num; |
394 | auto ub = cst.getConstantBound64(type: BoundType::UB, pos: dimExprCol); |
395 | // If `var_n` has an upperbound that is less than the divisor, mod can be |
396 | // eliminated altogether. |
397 | if (ub && *ub < divisor) |
398 | memo[pos] = dimExpr; |
399 | else |
400 | memo[pos] = dimExpr % divisor; |
401 | // If a unique quotient `var_q` was seen, it can be expressed as |
402 | // `var_n floordiv divisor`. |
403 | if (quotientCount == 1 && !memo[quotientPosition]) |
404 | memo[quotientPosition] = dimExpr.floorDiv(v: divisor) * quotientSign; |
405 | |
406 | return true; |
407 | } |
408 | } |
409 | return false; |
410 | } |
411 | |
412 | /// Check if the pos^th variable can be expressed as a floordiv of an affine |
413 | /// function of other variables (where the divisor is a positive constant) |
414 | /// given the initial set of expressions in `exprs`. If it can be, the |
415 | /// corresponding position in `exprs` is set as the detected affine expr. For |
416 | /// eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. An equality can |
417 | /// also yield a floordiv: eg. 4q = i + j <=> q = (i + j) floordiv 4. 32q + 28 |
418 | /// <= i <= 32q + 31 => q = i floordiv 32. |
419 | static bool detectAsFloorDiv(const FlatLinearConstraints &cst, unsigned pos, |
420 | MLIRContext *context, |
421 | SmallVectorImpl<AffineExpr> &exprs) { |
422 | assert(pos < cst.getNumVars() && "invalid position"); |
423 | |
424 | // Get upper-lower bound pair for this variable. |
425 | SmallVector<bool, 8> foundRepr(cst.getNumVars(), false); |
426 | for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i) |
427 | if (exprs[i]) |
428 | foundRepr[i] = true; |
429 | |
430 | SmallVector<int64_t, 8> dividend(cst.getNumCols()); |
431 | unsigned divisor; |
432 | auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor); |
433 | |
434 | // No upper-lower bound pair found for this var. |
435 | if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality) |
436 | return false; |
437 | |
438 | // Construct the dividend expression. |
439 | auto dividendExpr = getAffineConstantExpr(constant: dividend.back(), context); |
440 | for (unsigned c = 0, f = cst.getNumVars(); c < f; c++) |
441 | if (dividend[c] != 0) |
442 | dividendExpr = dividendExpr + dividend[c] * exprs[c]; |
443 | |
444 | // Successfully detected the floordiv. |
445 | exprs[pos] = dividendExpr.floorDiv(v: divisor); |
446 | return true; |
447 | } |
448 | |
449 | void FlatLinearConstraints::dumpRow(ArrayRef<int64_t> row, |
450 | bool fixedColWidth) const { |
451 | unsigned ncols = getNumCols(); |
452 | bool firstNonZero = true; |
453 | for (unsigned j = 0; j < ncols; j++) { |
454 | if (j == ncols - 1) { |
455 | // Constant. |
456 | if (row[j] == 0 && !firstNonZero) { |
457 | if (fixedColWidth) |
458 | llvm::errs().indent(NumSpaces: 7); |
459 | } else { |
460 | llvm::errs() << ((row[j] >= 0) ? "+ ": "") << row[j] << ' '; |
461 | } |
462 | } else { |
463 | std::string var = std::string("c_") + std::to_string(val: j); |
464 | if (row[j] == 1) |
465 | llvm::errs() << "+ "<< var << ' '; |
466 | else if (row[j] == -1) |
467 | llvm::errs() << "- "<< var << ' '; |
468 | else if (row[j] >= 2) |
469 | llvm::errs() << "+ "<< row[j] << '*' << var << ' '; |
470 | else if (row[j] <= -2) |
471 | llvm::errs() << "- "<< -row[j] << '*' << var << ' '; |
472 | else if (fixedColWidth) |
473 | // Zero coeff. |
474 | llvm::errs().indent(NumSpaces: 7); |
475 | if (row[j] != 0) |
476 | firstNonZero = false; |
477 | } |
478 | } |
479 | } |
480 | |
481 | void FlatLinearConstraints::dumpPretty() const { |
482 | assert(hasConsistentState()); |
483 | llvm::errs() << "Constraints ("<< getNumDimVars() << " dims, " |
484 | << getNumSymbolVars() << " symbols, "<< getNumLocalVars() |
485 | << " locals), ("<< getNumConstraints() << " constraints)\n"; |
486 | auto dumpConstraint = [&](unsigned rowPos, bool isEq) { |
487 | // Is it the first non-zero entry? |
488 | SmallVector<int64_t> row = |
489 | isEq ? getEquality64(idx: rowPos) : getInequality64(idx: rowPos); |
490 | dumpRow(row); |
491 | llvm::errs() << (isEq ? "=": ">=") << " 0\n"; |
492 | }; |
493 | |
494 | for (unsigned i = 0, e = getNumInequalities(); i < e; i++) |
495 | dumpConstraint(i, /*isEq=*/false); |
496 | for (unsigned i = 0, e = getNumEqualities(); i < e; i++) |
497 | dumpConstraint(i, /*isEq=*/true); |
498 | llvm::errs() << '\n'; |
499 | } |
500 | |
501 | std::pair<AffineMap, AffineMap> FlatLinearConstraints::getLowerAndUpperBound( |
502 | unsigned pos, unsigned offset, unsigned num, unsigned symStartPos, |
503 | ArrayRef<AffineExpr> localExprs, MLIRContext *context, |
504 | bool closedUB) const { |
505 | assert(pos + offset < getNumDimVars() && "invalid dim start pos"); |
506 | assert(symStartPos >= (pos + offset) && "invalid sym start pos"); |
507 | assert(getNumLocalVars() == localExprs.size() && |
508 | "incorrect local exprs count"); |
509 | |
510 | SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices; |
511 | getLowerAndUpperBoundIndices(pos: pos + offset, lbIndices: &lbIndices, ubIndices: &ubIndices, eqIndices: &eqIndices, |
512 | offset, num); |
513 | |
514 | /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos). |
515 | auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) { |
516 | b.clear(); |
517 | for (unsigned i = 0, e = a.size(); i < e; ++i) { |
518 | if (i < offset || i >= offset + num) |
519 | b.push_back(Elt: a[i]); |
520 | } |
521 | }; |
522 | |
523 | SmallVector<int64_t, 8> lb, ub; |
524 | SmallVector<AffineExpr, 4> lbExprs; |
525 | unsigned dimCount = symStartPos - num; |
526 | unsigned symCount = getNumDimAndSymbolVars() - symStartPos; |
527 | lbExprs.reserve(N: lbIndices.size() + eqIndices.size()); |
528 | // Lower bound expressions. |
529 | for (auto idx : lbIndices) { |
530 | auto ineq = getInequality64(idx); |
531 | // Extract the lower bound (in terms of other coeff's + const), i.e., if |
532 | // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j |
533 | // - 1. |
534 | addCoeffs(ineq, lb); |
535 | std::transform(first: lb.begin(), last: lb.end(), result: lb.begin(), unary_op: std::negate<int64_t>()); |
536 | auto expr = |
537 | getAffineExprFromFlatForm(flatExprs: lb, numDims: dimCount, numSymbols: symCount, localExprs, context); |
538 | // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor |
539 | int64_t divisor = std::abs(i: ineq[pos + offset]); |
540 | expr = (expr + divisor - 1).floorDiv(v: divisor); |
541 | lbExprs.push_back(Elt: expr); |
542 | } |
543 | |
544 | SmallVector<AffineExpr, 4> ubExprs; |
545 | ubExprs.reserve(N: ubIndices.size() + eqIndices.size()); |
546 | // Upper bound expressions. |
547 | for (auto idx : ubIndices) { |
548 | auto ineq = getInequality64(idx); |
549 | // Extract the upper bound (in terms of other coeff's + const). |
550 | addCoeffs(ineq, ub); |
551 | auto expr = |
552 | getAffineExprFromFlatForm(flatExprs: ub, numDims: dimCount, numSymbols: symCount, localExprs, context); |
553 | expr = expr.floorDiv(v: std::abs(i: ineq[pos + offset])); |
554 | int64_t ubAdjustment = closedUB ? 0 : 1; |
555 | ubExprs.push_back(Elt: expr + ubAdjustment); |
556 | } |
557 | |
558 | // Equalities. It's both a lower and a upper bound. |
559 | SmallVector<int64_t, 4> b; |
560 | for (auto idx : eqIndices) { |
561 | auto eq = getEquality64(idx); |
562 | addCoeffs(eq, b); |
563 | if (eq[pos + offset] > 0) |
564 | std::transform(first: b.begin(), last: b.end(), result: b.begin(), unary_op: std::negate<int64_t>()); |
565 | |
566 | // Extract the upper bound (in terms of other coeff's + const). |
567 | auto expr = |
568 | getAffineExprFromFlatForm(flatExprs: b, numDims: dimCount, numSymbols: symCount, localExprs, context); |
569 | expr = expr.floorDiv(v: std::abs(i: eq[pos + offset])); |
570 | // Upper bound is exclusive. |
571 | ubExprs.push_back(Elt: expr + 1); |
572 | // Lower bound. |
573 | expr = |
574 | getAffineExprFromFlatForm(flatExprs: b, numDims: dimCount, numSymbols: symCount, localExprs, context); |
575 | expr = expr.ceilDiv(v: std::abs(i: eq[pos + offset])); |
576 | lbExprs.push_back(Elt: expr); |
577 | } |
578 | |
579 | auto lbMap = AffineMap::get(dimCount, symbolCount: symCount, results: lbExprs, context); |
580 | auto ubMap = AffineMap::get(dimCount, symbolCount: symCount, results: ubExprs, context); |
581 | |
582 | return {lbMap, ubMap}; |
583 | } |
584 | |
585 | /// Express the pos^th identifier of `cst` as an affine expression in |
586 | /// terms of other identifiers, if they are available in `exprs`, using the |
587 | /// equality at position `idx` in `cs`t. Populates `exprs` with such an |
588 | /// expression if possible, and return true. Returns false otherwise. |
589 | static bool detectAsExpr(const FlatLinearConstraints &cst, unsigned pos, |
590 | unsigned idx, MLIRContext *context, |
591 | SmallVectorImpl<AffineExpr> &exprs) { |
592 | // Initialize with a `0` expression. |
593 | auto expr = getAffineConstantExpr(constant: 0, context); |
594 | |
595 | // Traverse `idx`th equality and construct the possible affine expression in |
596 | // terms of known identifiers. |
597 | unsigned j, e; |
598 | for (j = 0, e = cst.getNumVars(); j < e; ++j) { |
599 | if (j == pos) |
600 | continue; |
601 | int64_t c = cst.atEq64(i: idx, j); |
602 | if (c == 0) |
603 | continue; |
604 | // If any of the involved IDs hasn't been found yet, we can't proceed. |
605 | if (!exprs[j]) |
606 | break; |
607 | expr = expr + exprs[j] * c; |
608 | } |
609 | if (j < e) |
610 | // Can't construct expression as it depends on a yet uncomputed |
611 | // identifier. |
612 | return false; |
613 | |
614 | // Add constant term to AffineExpr. |
615 | expr = expr + cst.atEq64(i: idx, j: cst.getNumVars()); |
616 | int64_t vPos = cst.atEq64(i: idx, j: pos); |
617 | assert(vPos != 0 && "expected non-zero here"); |
618 | if (vPos > 0) |
619 | expr = (-expr).floorDiv(v: vPos); |
620 | else |
621 | // vPos < 0. |
622 | expr = expr.floorDiv(v: -vPos); |
623 | // Successfully constructed expression. |
624 | exprs[pos] = expr; |
625 | return true; |
626 | } |
627 | |
628 | /// Compute a representation of `num` identifiers starting at `offset` in `cst` |
629 | /// as affine expressions involving other known identifiers. Each identifier's |
630 | /// expression (in terms of known identifiers) is populated into `memo`. |
631 | static void computeUnknownVars(const FlatLinearConstraints &cst, |
632 | MLIRContext *context, unsigned offset, |
633 | unsigned num, |
634 | SmallVectorImpl<AffineExpr> &memo) { |
635 | // Initialize dimensional and symbolic variables. |
636 | for (unsigned i = 0, e = cst.getNumDimVars(); i < e; i++) { |
637 | if (i < offset) |
638 | memo[i] = getAffineDimExpr(position: i, context); |
639 | else if (i >= offset + num) |
640 | memo[i] = getAffineDimExpr(position: i - num, context); |
641 | } |
642 | for (unsigned i = cst.getNumDimVars(), e = cst.getNumDimAndSymbolVars(); |
643 | i < e; i++) |
644 | memo[i] = getAffineSymbolExpr(position: i - cst.getNumDimVars(), context); |
645 | |
646 | bool changed; |
647 | do { |
648 | changed = false; |
649 | // Identify yet unknown variables as constants or mod's / floordiv's of |
650 | // other variables if possible. |
651 | for (unsigned pos = 0, f = cst.getNumVars(); pos < f; pos++) { |
652 | if (memo[pos]) |
653 | continue; |
654 | |
655 | auto lbConst = cst.getConstantBound64(type: BoundType::LB, pos); |
656 | auto ubConst = cst.getConstantBound64(type: BoundType::UB, pos); |
657 | if (lbConst.has_value() && ubConst.has_value()) { |
658 | // Detect equality to a constant. |
659 | if (*lbConst == *ubConst) { |
660 | memo[pos] = getAffineConstantExpr(constant: *lbConst, context); |
661 | changed = true; |
662 | continue; |
663 | } |
664 | |
665 | // Detect a variable as modulo of another variable w.r.t a |
666 | // constant. |
667 | if (detectAsMod(cst, pos, offset, num, lbConst: *lbConst, ubConst: *ubConst, context, |
668 | memo)) { |
669 | changed = true; |
670 | continue; |
671 | } |
672 | } |
673 | |
674 | // Detect a variable as a floordiv of an affine function of other |
675 | // variables (divisor is a positive constant). |
676 | if (detectAsFloorDiv(cst, pos, context, exprs&: memo)) { |
677 | changed = true; |
678 | continue; |
679 | } |
680 | |
681 | // Detect a variable as an expression of other variables. |
682 | std::optional<unsigned> idx; |
683 | if (!(idx = cst.findConstraintWithNonZeroAt(colIdx: pos, /*isEq=*/true))) |
684 | continue; |
685 | |
686 | if (detectAsExpr(cst, pos, idx: *idx, context, exprs&: memo)) { |
687 | changed = true; |
688 | continue; |
689 | } |
690 | } |
691 | // This loop is guaranteed to reach a fixed point - since once an |
692 | // variable's explicit form is computed (in memo[pos]), it's not updated |
693 | // again. |
694 | } while (changed); |
695 | } |
696 | |
697 | /// Computes the lower and upper bounds of the first 'num' dimensional |
698 | /// variables (starting at 'offset') as affine maps of the remaining |
699 | /// variables (dimensional and symbolic variables). Local variables are |
700 | /// themselves explicitly computed as affine functions of other variables in |
701 | /// this process if needed. |
702 | void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num, |
703 | MLIRContext *context, |
704 | SmallVectorImpl<AffineMap> *lbMaps, |
705 | SmallVectorImpl<AffineMap> *ubMaps, |
706 | bool closedUB) { |
707 | assert(offset + num <= getNumDimVars() && "invalid range"); |
708 | |
709 | // Basic simplification. |
710 | normalizeConstraintsByGCD(); |
711 | |
712 | LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for variables at positions [" |
713 | << offset << ", "<< offset + num << ")\n"); |
714 | LLVM_DEBUG(dumpPretty()); |
715 | |
716 | // Record computed/detected variables. |
717 | SmallVector<AffineExpr, 8> memo(getNumVars()); |
718 | computeUnknownVars(cst: *this, context, offset, num, memo); |
719 | |
720 | int64_t ubAdjustment = closedUB ? 0 : 1; |
721 | |
722 | // Set the lower and upper bound maps for all the variables that were |
723 | // computed as affine expressions of the rest as the "detected expr" and |
724 | // "detected expr + 1" respectively; set the undetected ones to null. |
725 | std::optional<FlatLinearConstraints> tmpClone; |
726 | for (unsigned pos = 0; pos < num; pos++) { |
727 | unsigned numMapDims = getNumDimVars() - num; |
728 | unsigned numMapSymbols = getNumSymbolVars(); |
729 | AffineExpr expr = memo[pos + offset]; |
730 | if (expr) |
731 | expr = simplifyAffineExpr(expr, numDims: numMapDims, numSymbols: numMapSymbols); |
732 | |
733 | AffineMap &lbMap = (*lbMaps)[pos]; |
734 | AffineMap &ubMap = (*ubMaps)[pos]; |
735 | |
736 | if (expr) { |
737 | lbMap = AffineMap::get(dimCount: numMapDims, symbolCount: numMapSymbols, result: expr); |
738 | ubMap = AffineMap::get(dimCount: numMapDims, symbolCount: numMapSymbols, result: expr + ubAdjustment); |
739 | } else { |
740 | // TODO: Whenever there are local variables in the dependence |
741 | // constraints, we'll conservatively over-approximate, since we don't |
742 | // always explicitly compute them above (in the while loop). |
743 | if (getNumLocalVars() == 0) { |
744 | // Work on a copy so that we don't update this constraint system. |
745 | if (!tmpClone) { |
746 | tmpClone.emplace(args: FlatLinearConstraints(*this)); |
747 | // Removing redundant inequalities is necessary so that we don't get |
748 | // redundant loop bounds. |
749 | tmpClone->removeRedundantInequalities(); |
750 | } |
751 | std::tie(args&: lbMap, args&: ubMap) = tmpClone->getLowerAndUpperBound( |
752 | pos, offset, num, symStartPos: getNumDimVars(), /*localExprs=*/{}, context, |
753 | closedUB); |
754 | } |
755 | |
756 | // If the above fails, we'll just use the constant lower bound and the |
757 | // constant upper bound (if they exist) as the slice bounds. |
758 | // TODO: being conservative for the moment in cases that |
759 | // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is |
760 | // fixed (b/126426796). |
761 | if (!lbMap || lbMap.getNumResults() != 1) { |
762 | LLVM_DEBUG(llvm::dbgs() |
763 | << "WARNING: Potentially over-approximating slice lb\n"); |
764 | auto lbConst = getConstantBound64(type: BoundType::LB, pos: pos + offset); |
765 | if (lbConst.has_value()) { |
766 | lbMap = AffineMap::get(dimCount: numMapDims, symbolCount: numMapSymbols, |
767 | result: getAffineConstantExpr(constant: *lbConst, context)); |
768 | } |
769 | } |
770 | if (!ubMap || ubMap.getNumResults() != 1) { |
771 | LLVM_DEBUG(llvm::dbgs() |
772 | << "WARNING: Potentially over-approximating slice ub\n"); |
773 | auto ubConst = getConstantBound64(type: BoundType::UB, pos: pos + offset); |
774 | if (ubConst.has_value()) { |
775 | ubMap = AffineMap::get( |
776 | dimCount: numMapDims, symbolCount: numMapSymbols, |
777 | result: getAffineConstantExpr(constant: *ubConst + ubAdjustment, context)); |
778 | } |
779 | } |
780 | } |
781 | |
782 | LLVM_DEBUG(llvm::dbgs() << "Slice bounds:\n"); |
783 | LLVM_DEBUG(llvm::dbgs() << "lb map for pos = "<< Twine(pos + offset) |
784 | << ", expr: "<< lbMap << '\n'); |
785 | LLVM_DEBUG(llvm::dbgs() << "ub map for pos = "<< Twine(pos + offset) |
786 | << ", expr: "<< ubMap << '\n'); |
787 | } |
788 | } |
789 | |
790 | LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals( |
791 | AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, |
792 | bool addConservativeSemiAffineBounds) { |
793 | FlatLinearConstraints localCst; |
794 | if (failed(Result: getFlattenedAffineExprs(map, flattenedExprs, localVarCst: &localCst, |
795 | addConservativeSemiAffineBounds))) { |
796 | LLVM_DEBUG(llvm::dbgs() |
797 | << "composition unimplemented for semi-affine maps\n"); |
798 | return failure(); |
799 | } |
800 | |
801 | // Add localCst information. |
802 | if (localCst.getNumLocalVars() > 0) { |
803 | unsigned numLocalVars = getNumLocalVars(); |
804 | // Insert local dims of localCst at the beginning. |
805 | insertLocalVar(/*pos=*/0, /*num=*/localCst.getNumLocalVars()); |
806 | // Insert local dims of `this` at the end of localCst. |
807 | localCst.appendLocalVar(/*num=*/numLocalVars); |
808 | // Dimensions of localCst and this constraint set match. Append localCst to |
809 | // this constraint set. |
810 | append(other: localCst); |
811 | } |
812 | |
813 | return success(); |
814 | } |
815 | |
816 | LogicalResult FlatLinearConstraints::addBound( |
817 | BoundType type, unsigned pos, AffineMap boundMap, bool isClosedBound, |
818 | AddConservativeSemiAffineBounds addSemiAffineBounds) { |
819 | assert(boundMap.getNumDims() == getNumDimVars() && "dim mismatch"); |
820 | assert(boundMap.getNumSymbols() == getNumSymbolVars() && "symbol mismatch"); |
821 | assert(pos < getNumDimAndSymbolVars() && "invalid position"); |
822 | assert((type != BoundType::EQ || isClosedBound) && |
823 | "EQ bound must be closed."); |
824 | |
825 | // Equality follows the logic of lower bound except that we add an equality |
826 | // instead of an inequality. |
827 | assert((type != BoundType::EQ || boundMap.getNumResults() == 1) && |
828 | "single result expected"); |
829 | bool lower = type == BoundType::LB || type == BoundType::EQ; |
830 | |
831 | std::vector<SmallVector<int64_t, 8>> flatExprs; |
832 | if (failed(Result: flattenAlignedMapAndMergeLocals( |
833 | map: boundMap, flattenedExprs: &flatExprs, |
834 | addConservativeSemiAffineBounds: addSemiAffineBounds == AddConservativeSemiAffineBounds::Yes))) |
835 | return failure(); |
836 | assert(flatExprs.size() == boundMap.getNumResults()); |
837 | |
838 | // Add one (in)equality for each result. |
839 | for (const auto &flatExpr : flatExprs) { |
840 | SmallVector<int64_t> ineq(getNumCols(), 0); |
841 | // Dims and symbols. |
842 | for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) { |
843 | ineq[j] = lower ? -flatExpr[j] : flatExpr[j]; |
844 | } |
845 | // Invalid bound: pos appears in `boundMap`. |
846 | // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or |
847 | // its callers to prevent invalid bounds from being added. |
848 | if (ineq[pos] != 0) |
849 | continue; |
850 | ineq[pos] = lower ? 1 : -1; |
851 | // Local columns of `ineq` are at the beginning. |
852 | unsigned j = getNumDimVars() + getNumSymbolVars(); |
853 | unsigned end = flatExpr.size() - 1; |
854 | for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) { |
855 | ineq[j] = lower ? -flatExpr[i] : flatExpr[i]; |
856 | } |
857 | // Make the bound closed in if flatExpr is open. The inequality is always |
858 | // created in the upper bound form, so the adjustment is -1. |
859 | int64_t boundAdjustment = (isClosedBound || type == BoundType::EQ) ? 0 : -1; |
860 | // Constant term. |
861 | ineq[getNumCols() - 1] = (lower ? -flatExpr[flatExpr.size() - 1] |
862 | : flatExpr[flatExpr.size() - 1]) + |
863 | boundAdjustment; |
864 | type == BoundType::EQ ? addEquality(eq: ineq) : addInequality(inEq: ineq); |
865 | } |
866 | |
867 | return success(); |
868 | } |
869 | |
870 | LogicalResult FlatLinearConstraints::addBound( |
871 | BoundType type, unsigned pos, AffineMap boundMap, |
872 | AddConservativeSemiAffineBounds addSemiAffineBounds) { |
873 | return addBound(type, pos, boundMap, |
874 | /*isClosedBound=*/type != BoundType::UB, addSemiAffineBounds); |
875 | } |
876 | |
877 | /// Compute an explicit representation for local vars. For all systems coming |
878 | /// from MLIR integer sets, maps, or expressions where local vars were |
879 | /// introduced to model floordivs and mods, this always succeeds. |
880 | LogicalResult |
881 | FlatLinearConstraints::computeLocalVars(SmallVectorImpl<AffineExpr> &memo, |
882 | MLIRContext *context) const { |
883 | unsigned numDims = getNumDimVars(); |
884 | unsigned numSyms = getNumSymbolVars(); |
885 | |
886 | // Initialize dimensional and symbolic variables. |
887 | for (unsigned i = 0; i < numDims; i++) |
888 | memo[i] = getAffineDimExpr(position: i, context); |
889 | for (unsigned i = numDims, e = numDims + numSyms; i < e; i++) |
890 | memo[i] = getAffineSymbolExpr(position: i - numDims, context); |
891 | |
892 | bool changed; |
893 | do { |
894 | // Each time `changed` is true at the end of this iteration, one or more |
895 | // local vars would have been detected as floordivs and set in memo; so the |
896 | // number of null entries in memo[...] strictly reduces; so this converges. |
897 | changed = false; |
898 | for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i) |
899 | if (!memo[numDims + numSyms + i] && |
900 | detectAsFloorDiv(cst: *this, /*pos=*/numDims + numSyms + i, context, exprs&: memo)) |
901 | changed = true; |
902 | } while (changed); |
903 | |
904 | ArrayRef<AffineExpr> localExprs = |
905 | ArrayRef<AffineExpr>(memo).take_back(N: getNumLocalVars()); |
906 | return success( |
907 | IsSuccess: llvm::all_of(Range&: localExprs, P: [](AffineExpr expr) { return expr; })); |
908 | } |
909 | |
910 | /// Given an equality or inequality (`isEquality` used to disambiguate) of `cst` |
911 | /// at `idx`, traverse and sum up `AffineExpr`s of all known ids other than the |
912 | /// `pos`th. Known `AffineExpr`s are given in `exprs` (unknowns are null). If |
913 | /// the equality/inequality contains any unknown id, return None. Otherwise |
914 | /// return sum as `AffineExpr`. |
915 | static std::optional<AffineExpr> getAsExpr(const FlatLinearConstraints &cst, |
916 | unsigned pos, MLIRContext *context, |
917 | ArrayRef<AffineExpr> exprs, |
918 | unsigned idx, bool isEquality) { |
919 | // Initialize with a `0` expression. |
920 | auto expr = getAffineConstantExpr(constant: 0, context); |
921 | |
922 | SmallVector<int64_t, 8> row = |
923 | isEquality ? cst.getEquality64(idx) : cst.getInequality64(idx); |
924 | |
925 | // Traverse `idx`th equality and construct the possible affine expression in |
926 | // terms of known identifiers. |
927 | unsigned j, e; |
928 | for (j = 0, e = cst.getNumVars(); j < e; ++j) { |
929 | if (j == pos) |
930 | continue; |
931 | int64_t c = row[j]; |
932 | if (c == 0) |
933 | continue; |
934 | // If any of the involved IDs hasn't been found yet, we can't proceed. |
935 | if (!exprs[j]) |
936 | break; |
937 | expr = expr + exprs[j] * c; |
938 | } |
939 | if (j < e) |
940 | // Can't construct expression as it depends on a yet uncomputed |
941 | // identifier. |
942 | return std::nullopt; |
943 | |
944 | // Add constant term to AffineExpr. |
945 | expr = expr + row[cst.getNumVars()]; |
946 | return expr; |
947 | } |
948 | |
949 | std::optional<int64_t> FlatLinearConstraints::getConstantBoundOnDimSize( |
950 | MLIRContext *context, unsigned pos, AffineMap *lb, AffineMap *ub, |
951 | unsigned *minLbPos, unsigned *minUbPos) const { |
952 | |
953 | assert(pos < getNumDimVars() && "Invalid identifier position"); |
954 | |
955 | auto freeOfUnknownLocalVars = [&](ArrayRef<int64_t> cst, |
956 | ArrayRef<AffineExpr> whiteListCols) { |
957 | for (int i = getNumDimAndSymbolVars(), e = cst.size() - 1; i < e; ++i) { |
958 | if (whiteListCols[i] && whiteListCols[i].isSymbolicOrConstant()) |
959 | continue; |
960 | if (cst[i] != 0) |
961 | return false; |
962 | } |
963 | return true; |
964 | }; |
965 | |
966 | // Detect the necesary local variables first. |
967 | SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr()); |
968 | (void)computeLocalVars(memo, context); |
969 | |
970 | // Find an equality for 'pos'^th identifier that equates it to some function |
971 | // of the symbolic identifiers (+ constant). |
972 | int eqPos = findEqualityToConstant(pos, /*symbolic=*/true); |
973 | // If the equality involves a local var that can not be expressed as a |
974 | // symbolic or constant affine expression, we bail out. |
975 | if (eqPos != -1 && freeOfUnknownLocalVars(getEquality64(idx: eqPos), memo)) { |
976 | // This identifier can only take a single value. |
977 | if (lb && detectAsExpr(cst: *this, pos, idx: eqPos, context, exprs&: memo)) { |
978 | AffineExpr equalityExpr = |
979 | simplifyAffineExpr(expr: memo[pos], numDims: 0, numSymbols: getNumSymbolVars()); |
980 | *lb = AffineMap::get(/*dimCount=*/0, symbolCount: getNumSymbolVars(), result: equalityExpr); |
981 | if (ub) |
982 | *ub = *lb; |
983 | } |
984 | if (minLbPos) |
985 | *minLbPos = eqPos; |
986 | if (minUbPos) |
987 | *minUbPos = eqPos; |
988 | return 1; |
989 | } |
990 | |
991 | // Positions of constraints that are lower/upper bounds on the variable. |
992 | SmallVector<unsigned, 4> lbIndices, ubIndices; |
993 | |
994 | // Note inequalities that give lower and upper bounds. |
995 | getLowerAndUpperBoundIndices(pos, lbIndices: &lbIndices, ubIndices: &ubIndices, |
996 | /*eqIndices=*/nullptr, /*offset=*/0, |
997 | /*num=*/getNumDimVars()); |
998 | |
999 | std::optional<int64_t> minDiff = std::nullopt; |
1000 | unsigned minLbPosition = 0, minUbPosition = 0; |
1001 | AffineExpr minLbExpr, minUbExpr; |
1002 | |
1003 | // Traverse each lower bound and upper bound pair, to compute the difference |
1004 | // between them. |
1005 | for (unsigned ubPos : ubIndices) { |
1006 | // Construct sum of all ids other than `pos`th in the given upper bound row. |
1007 | std::optional<AffineExpr> maybeUbExpr = |
1008 | getAsExpr(cst: *this, pos, context, exprs: memo, idx: ubPos, /*isEquality=*/false); |
1009 | if (!maybeUbExpr.has_value() || !(*maybeUbExpr).isSymbolicOrConstant()) |
1010 | continue; |
1011 | |
1012 | // Canonical form of an inequality that constrains the upper bound on |
1013 | // an id `x_i` is of the form: |
1014 | // `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` <= -1. |
1015 | // Therefore the upper bound on `x_i` will be |
1016 | // `( |
1017 | // sum(c_j*x_j) where j != i |
1018 | // + |
1019 | // c_0 |
1020 | // ) |
1021 | // / |
1022 | // -(c_i)`. Divison here is a floorDiv. |
1023 | AffineExpr ubExpr = maybeUbExpr->floorDiv(v: -atIneq64(i: ubPos, j: pos)); |
1024 | assert(-atIneq64(ubPos, pos) > 0 && "invalid upper bound index"); |
1025 | |
1026 | // Go over each lower bound. |
1027 | for (unsigned lbPos : lbIndices) { |
1028 | // Construct sum of all ids other than `pos`th in the given lower bound |
1029 | // row. |
1030 | std::optional<AffineExpr> maybeLbExpr = |
1031 | getAsExpr(cst: *this, pos, context, exprs: memo, idx: lbPos, /*isEquality=*/false); |
1032 | if (!maybeLbExpr.has_value() || !(*maybeLbExpr).isSymbolicOrConstant()) |
1033 | continue; |
1034 | |
1035 | // Canonical form of an inequality that is constraining the lower bound |
1036 | // on an id `x_i is of the form: |
1037 | // `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` >= 1. |
1038 | // Therefore upperBound on `x_i` will be |
1039 | // `-( |
1040 | // sum(c_j*x_j) where j != i |
1041 | // + |
1042 | // c_0 |
1043 | // ) |
1044 | // / |
1045 | // c_i`. Divison here is a ceilDiv. |
1046 | int64_t divisor = atIneq64(i: lbPos, j: pos); |
1047 | // We convert the `ceilDiv` for floordiv with the formula: |
1048 | // `expr ceildiv divisor is (expr + divisor - 1) floordiv divisor`, |
1049 | // since uniformly keeping divisons as `floorDiv` helps their |
1050 | // simplification. |
1051 | AffineExpr lbExpr = (-(*maybeLbExpr) + divisor - 1).floorDiv(v: divisor); |
1052 | assert(atIneq64(lbPos, pos) > 0 && "invalid lower bound index"); |
1053 | |
1054 | AffineExpr difference = |
1055 | simplifyAffineExpr(expr: ubExpr - lbExpr + 1, numDims: 0, numSymbols: getNumSymbolVars()); |
1056 | // If the difference is not constant, ignore the lower bound - upper bound |
1057 | // pair. |
1058 | auto constantDiff = dyn_cast<AffineConstantExpr>(Val&: difference); |
1059 | if (!constantDiff) |
1060 | continue; |
1061 | |
1062 | int64_t diffValue = constantDiff.getValue(); |
1063 | // This bound is non-negative by definition. |
1064 | diffValue = std::max<int64_t>(a: diffValue, b: 0); |
1065 | if (!minDiff || diffValue < *minDiff) { |
1066 | minDiff = diffValue; |
1067 | minLbPosition = lbPos; |
1068 | minUbPosition = ubPos; |
1069 | minLbExpr = lbExpr; |
1070 | minUbExpr = ubExpr; |
1071 | } |
1072 | } |
1073 | } |
1074 | |
1075 | // Populate outputs where available and needed. |
1076 | if (lb && minDiff) { |
1077 | *lb = AffineMap::get(/*dimCount=*/0, symbolCount: getNumSymbolVars(), result: minLbExpr); |
1078 | } |
1079 | if (ub) |
1080 | *ub = AffineMap::get(/*dimCount=*/0, symbolCount: getNumSymbolVars(), result: minUbExpr); |
1081 | if (minLbPos) |
1082 | *minLbPos = minLbPosition; |
1083 | if (minUbPos) |
1084 | *minUbPos = minUbPosition; |
1085 | |
1086 | return minDiff; |
1087 | } |
1088 | |
1089 | IntegerSet FlatLinearConstraints::getAsIntegerSet(MLIRContext *context) const { |
1090 | if (getNumConstraints() == 0) |
1091 | // Return universal set (always true): 0 == 0. |
1092 | return IntegerSet::get(dimCount: getNumDimVars(), symbolCount: getNumSymbolVars(), |
1093 | constraints: getAffineConstantExpr(/*constant=*/0, context), |
1094 | /*eqFlags=*/true); |
1095 | |
1096 | // Construct local references. |
1097 | SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr()); |
1098 | |
1099 | if (failed(Result: computeLocalVars(memo, context))) { |
1100 | // Check if the local variables without an explicit representation have |
1101 | // zero coefficients everywhere. |
1102 | SmallVector<unsigned> noLocalRepVars; |
1103 | unsigned numDimsSymbols = getNumDimAndSymbolVars(); |
1104 | for (unsigned i = numDimsSymbols, e = getNumVars(); i < e; ++i) { |
1105 | if (!memo[i] && !isColZero(/*pos=*/i)) |
1106 | noLocalRepVars.push_back(Elt: i - numDimsSymbols); |
1107 | } |
1108 | if (!noLocalRepVars.empty()) { |
1109 | LLVM_DEBUG({ |
1110 | llvm::dbgs() << "local variables at position(s) " |
1111 | << llvm::interleaved(noLocalRepVars) |
1112 | << " do not have an explicit representation in:\n"; |
1113 | this->dump(); |
1114 | }); |
1115 | return IntegerSet(); |
1116 | } |
1117 | } |
1118 | |
1119 | ArrayRef<AffineExpr> localExprs = |
1120 | ArrayRef<AffineExpr>(memo).take_back(N: getNumLocalVars()); |
1121 | |
1122 | // Construct the IntegerSet from the equalities/inequalities. |
1123 | unsigned numDims = getNumDimVars(); |
1124 | unsigned numSyms = getNumSymbolVars(); |
1125 | |
1126 | SmallVector<bool, 16> eqFlags(getNumConstraints()); |
1127 | std::fill(first: eqFlags.begin(), last: eqFlags.begin() + getNumEqualities(), value: true); |
1128 | std::fill(first: eqFlags.begin() + getNumEqualities(), last: eqFlags.end(), value: false); |
1129 | |
1130 | SmallVector<AffineExpr, 8> exprs; |
1131 | exprs.reserve(N: getNumConstraints()); |
1132 | |
1133 | for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) |
1134 | exprs.push_back(Elt: getAffineExprFromFlatForm(flatExprs: getEquality64(idx: i), numDims, |
1135 | numSymbols: numSyms, localExprs, context)); |
1136 | for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) |
1137 | exprs.push_back(Elt: getAffineExprFromFlatForm(flatExprs: getInequality64(idx: i), numDims, |
1138 | numSymbols: numSyms, localExprs, context)); |
1139 | return IntegerSet::get(dimCount: numDims, symbolCount: numSyms, constraints: exprs, eqFlags); |
1140 | } |
1141 | |
1142 | //===----------------------------------------------------------------------===// |
1143 | // FlatLinearValueConstraints |
1144 | //===----------------------------------------------------------------------===// |
1145 | |
1146 | // Construct from an IntegerSet. |
1147 | FlatLinearValueConstraints::FlatLinearValueConstraints(IntegerSet set, |
1148 | ValueRange operands) |
1149 | : FlatLinearConstraints(set.getNumInequalities(), set.getNumEqualities(), |
1150 | set.getNumDims() + set.getNumSymbols() + 1, |
1151 | set.getNumDims(), set.getNumSymbols(), |
1152 | /*numLocals=*/0) { |
1153 | assert((operands.empty() || set.getNumInputs() == operands.size()) && |
1154 | "operand count mismatch"); |
1155 | // Set the values for the non-local variables. |
1156 | for (unsigned i = 0, e = operands.size(); i < e; ++i) |
1157 | setValue(pos: i, val: operands[i]); |
1158 | |
1159 | // Flatten expressions and add them to the constraint system. |
1160 | std::vector<SmallVector<int64_t, 8>> flatExprs; |
1161 | FlatLinearConstraints localVarCst; |
1162 | if (failed(Result: getFlattenedAffineExprs(set, flattenedExprs: &flatExprs, localVarCst: &localVarCst))) { |
1163 | assert(false && "flattening unimplemented for semi-affine integer sets"); |
1164 | return; |
1165 | } |
1166 | assert(flatExprs.size() == set.getNumConstraints()); |
1167 | insertVar(kind: VarKind::Local, pos: getNumVarKind(kind: VarKind::Local), |
1168 | /*num=*/localVarCst.getNumLocalVars()); |
1169 | |
1170 | for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { |
1171 | const auto &flatExpr = flatExprs[i]; |
1172 | assert(flatExpr.size() == getNumCols()); |
1173 | if (set.getEqFlags()[i]) { |
1174 | addEquality(eq: flatExpr); |
1175 | } else { |
1176 | addInequality(inEq: flatExpr); |
1177 | } |
1178 | } |
1179 | // Add the other constraints involving local vars from flattening. |
1180 | append(other: localVarCst); |
1181 | } |
1182 | |
1183 | unsigned FlatLinearValueConstraints::appendDimVar(ValueRange vals) { |
1184 | unsigned pos = getNumDimVars(); |
1185 | return insertVar(kind: VarKind::SetDim, pos, vals); |
1186 | } |
1187 | |
1188 | unsigned FlatLinearValueConstraints::appendSymbolVar(ValueRange vals) { |
1189 | unsigned pos = getNumSymbolVars(); |
1190 | return insertVar(kind: VarKind::Symbol, pos, vals); |
1191 | } |
1192 | |
1193 | unsigned FlatLinearValueConstraints::insertDimVar(unsigned pos, |
1194 | ValueRange vals) { |
1195 | return insertVar(kind: VarKind::SetDim, pos, vals); |
1196 | } |
1197 | |
1198 | unsigned FlatLinearValueConstraints::insertSymbolVar(unsigned pos, |
1199 | ValueRange vals) { |
1200 | return insertVar(kind: VarKind::Symbol, pos, vals); |
1201 | } |
1202 | |
1203 | unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos, |
1204 | unsigned num) { |
1205 | unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num); |
1206 | |
1207 | return absolutePos; |
1208 | } |
1209 | |
1210 | unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos, |
1211 | ValueRange vals) { |
1212 | assert(!vals.empty() && "expected ValueRange with Values."); |
1213 | assert(kind != VarKind::Local && |
1214 | "values cannot be attached to local variables."); |
1215 | unsigned num = vals.size(); |
1216 | unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num); |
1217 | |
1218 | // If a Value is provided, insert it; otherwise use std::nullopt. |
1219 | for (unsigned i = 0, e = vals.size(); i < e; ++i) |
1220 | if (vals[i]) |
1221 | setValue(pos: absolutePos + i, val: vals[i]); |
1222 | |
1223 | return absolutePos; |
1224 | } |
1225 | |
1226 | /// Checks if two constraint systems are in the same space, i.e., if they are |
1227 | /// associated with the same set of variables, appearing in the same order. |
1228 | static bool areVarsAligned(const FlatLinearValueConstraints &a, |
1229 | const FlatLinearValueConstraints &b) { |
1230 | if (a.getNumDomainVars() != b.getNumDomainVars() || |
1231 | a.getNumRangeVars() != b.getNumRangeVars() || |
1232 | a.getNumSymbolVars() != b.getNumSymbolVars()) |
1233 | return false; |
1234 | SmallVector<std::optional<Value>> aMaybeValues = a.getMaybeValues(), |
1235 | bMaybeValues = b.getMaybeValues(); |
1236 | return std::equal(first1: aMaybeValues.begin(), last1: aMaybeValues.end(), |
1237 | first2: bMaybeValues.begin(), last2: bMaybeValues.end()); |
1238 | } |
1239 | |
1240 | /// Calls areVarsAligned to check if two constraint systems have the same set |
1241 | /// of variables in the same order. |
1242 | bool FlatLinearValueConstraints::areVarsAlignedWithOther( |
1243 | const FlatLinearConstraints &other) { |
1244 | return areVarsAligned(a: *this, b: other); |
1245 | } |
1246 | |
1247 | /// Checks if the SSA values associated with `cst`'s variables in range |
1248 | /// [start, end) are unique. |
1249 | static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique( |
1250 | const FlatLinearValueConstraints &cst, unsigned start, unsigned end) { |
1251 | |
1252 | assert(start <= cst.getNumDimAndSymbolVars() && |
1253 | "Start position out of bounds"); |
1254 | assert(end <= cst.getNumDimAndSymbolVars() && "End position out of bounds"); |
1255 | |
1256 | if (start >= end) |
1257 | return true; |
1258 | |
1259 | SmallPtrSet<Value, 8> uniqueVars; |
1260 | SmallVector<std::optional<Value>, 8> maybeValuesAll = cst.getMaybeValues(); |
1261 | ArrayRef<std::optional<Value>> maybeValues = {maybeValuesAll.data() + start, |
1262 | maybeValuesAll.data() + end}; |
1263 | |
1264 | for (std::optional<Value> val : maybeValues) |
1265 | if (val && !uniqueVars.insert(Ptr: *val).second) |
1266 | return false; |
1267 | |
1268 | return true; |
1269 | } |
1270 | |
1271 | /// Checks if the SSA values associated with `cst`'s variables are unique. |
1272 | static bool LLVM_ATTRIBUTE_UNUSED |
1273 | areVarsUnique(const FlatLinearValueConstraints &cst) { |
1274 | return areVarsUnique(cst, start: 0, end: cst.getNumDimAndSymbolVars()); |
1275 | } |
1276 | |
1277 | /// Checks if the SSA values associated with `cst`'s variables of kind `kind` |
1278 | /// are unique. |
1279 | static bool LLVM_ATTRIBUTE_UNUSED |
1280 | areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) { |
1281 | |
1282 | if (kind == VarKind::SetDim) |
1283 | return areVarsUnique(cst, start: 0, end: cst.getNumDimVars()); |
1284 | if (kind == VarKind::Symbol) |
1285 | return areVarsUnique(cst, start: cst.getNumDimVars(), |
1286 | end: cst.getNumDimAndSymbolVars()); |
1287 | llvm_unreachable("Unexpected VarKind"); |
1288 | } |
1289 | |
1290 | /// Merge and align the variables of A and B starting at 'offset', so that |
1291 | /// both constraint systems get the union of the contained variables that is |
1292 | /// dimension-wise and symbol-wise unique; both constraint systems are updated |
1293 | /// so that they have the union of all variables, with A's original |
1294 | /// variables appearing first followed by any of B's variables that didn't |
1295 | /// appear in A. Local variables in B that have the same division |
1296 | /// representation as local variables in A are merged into one. We allow A |
1297 | /// and B to have non-unique values for their variables; in such cases, they are |
1298 | /// still aligned with the variables appearing first aligned with those |
1299 | /// appearing first in the other system from left to right. |
1300 | // E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M]) |
1301 | // Output: both A, B have (%i, %j, %k) [%M, %N, %P] |
1302 | static void mergeAndAlignVars(unsigned offset, FlatLinearValueConstraints *a, |
1303 | FlatLinearValueConstraints *b) { |
1304 | assert(offset <= a->getNumDimVars() && offset <= b->getNumDimVars()); |
1305 | |
1306 | assert(llvm::all_of( |
1307 | llvm::drop_begin(a->getMaybeValues(), offset), |
1308 | [](const std::optional<Value> &var) { return var.has_value(); })); |
1309 | |
1310 | assert(llvm::all_of( |
1311 | llvm::drop_begin(b->getMaybeValues(), offset), |
1312 | [](const std::optional<Value> &var) { return var.has_value(); })); |
1313 | |
1314 | SmallVector<Value, 4> aDimValues; |
1315 | a->getValues(start: offset, end: a->getNumDimVars(), values: &aDimValues); |
1316 | |
1317 | { |
1318 | // Merge dims from A into B. |
1319 | unsigned d = offset; |
1320 | for (Value aDimValue : aDimValues) { |
1321 | unsigned loc; |
1322 | // Find from the position `d` since we'd like to also consider the |
1323 | // possibility of multiple variables with the same `Value`. We align with |
1324 | // the next appearing one. |
1325 | if (b->findVar(val: aDimValue, pos: &loc, offset: d)) { |
1326 | assert(loc >= offset && "A's dim appears in B's aligned range"); |
1327 | assert(loc < b->getNumDimVars() && |
1328 | "A's dim appears in B's non-dim position"); |
1329 | b->swapVar(posA: d, posB: loc); |
1330 | } else { |
1331 | b->insertDimVar(pos: d, vals: aDimValue); |
1332 | } |
1333 | d++; |
1334 | } |
1335 | // Dimensions that are in B, but not in A, are added at the end. |
1336 | for (unsigned t = a->getNumDimVars(), e = b->getNumDimVars(); t < e; t++) { |
1337 | a->appendDimVar(vals: b->getValue(pos: t)); |
1338 | } |
1339 | assert(a->getNumDimVars() == b->getNumDimVars() && |
1340 | "expected same number of dims"); |
1341 | } |
1342 | |
1343 | // Merge and align symbols of A and B |
1344 | a->mergeSymbolVars(other&: *b); |
1345 | // Merge and align locals of A and B |
1346 | a->mergeLocalVars(other&: *b); |
1347 | |
1348 | assert(areVarsAligned(*a, *b) && "IDs expected to be aligned"); |
1349 | } |
1350 | |
1351 | // Call 'mergeAndAlignVars' to align constraint systems of 'this' and 'other'. |
1352 | void FlatLinearValueConstraints::mergeAndAlignVarsWithOther( |
1353 | unsigned offset, FlatLinearValueConstraints *other) { |
1354 | mergeAndAlignVars(offset, a: this, b: other); |
1355 | } |
1356 | |
1357 | /// Merge and align symbols of `this` and `other` such that both get union of |
1358 | /// of symbols. Existing symbols need not be unique; they will be aligned from |
1359 | /// left to right with duplicates aligned in the same order. Symbols with Value |
1360 | /// as `None` are considered to be inequal to all other symbols. |
1361 | void FlatLinearValueConstraints::mergeSymbolVars( |
1362 | FlatLinearValueConstraints &other) { |
1363 | |
1364 | SmallVector<Value, 4> aSymValues; |
1365 | getValues(start: getNumDimVars(), end: getNumDimAndSymbolVars(), values: &aSymValues); |
1366 | |
1367 | // Merge symbols: merge symbols into `other` first from `this`. |
1368 | unsigned s = other.getNumDimVars(); |
1369 | for (Value aSymValue : aSymValues) { |
1370 | unsigned loc; |
1371 | // If the var is a symbol in `other`, then align it, otherwise assume that |
1372 | // it is a new symbol. Search in `other` starting at position `s` since the |
1373 | // left of it is aligned. |
1374 | if (other.findVar(val: aSymValue, pos: &loc, offset: s) && loc >= other.getNumDimVars() && |
1375 | loc < other.getNumDimAndSymbolVars()) |
1376 | other.swapVar(posA: s, posB: loc); |
1377 | else |
1378 | other.insertSymbolVar(pos: s - other.getNumDimVars(), vals: aSymValue); |
1379 | s++; |
1380 | } |
1381 | |
1382 | // Symbols that are in other, but not in this, are added at the end. |
1383 | for (unsigned t = other.getNumDimVars() + getNumSymbolVars(), |
1384 | e = other.getNumDimAndSymbolVars(); |
1385 | t < e; t++) |
1386 | insertSymbolVar(pos: getNumSymbolVars(), vals: other.getValue(pos: t)); |
1387 | |
1388 | assert(getNumSymbolVars() == other.getNumSymbolVars() && |
1389 | "expected same number of symbols"); |
1390 | } |
1391 | |
1392 | void FlatLinearValueConstraints::removeVarRange(VarKind kind, unsigned varStart, |
1393 | unsigned varLimit) { |
1394 | IntegerPolyhedron::removeVarRange(kind, varStart, varLimit); |
1395 | } |
1396 | |
1397 | AffineMap |
1398 | FlatLinearValueConstraints::computeAlignedMap(AffineMap map, |
1399 | ValueRange operands) const { |
1400 | assert(map.getNumInputs() == operands.size() && "number of inputs mismatch"); |
1401 | |
1402 | SmallVector<Value> dims, syms; |
1403 | #ifndef NDEBUG |
1404 | SmallVector<Value> newSyms; |
1405 | SmallVector<Value> *newSymsPtr = &newSyms; |
1406 | #else |
1407 | SmallVector<Value> *newSymsPtr = nullptr; |
1408 | #endif // NDEBUG |
1409 | |
1410 | dims.reserve(N: getNumDimVars()); |
1411 | syms.reserve(N: getNumSymbolVars()); |
1412 | for (unsigned i = 0, e = getNumVarKind(kind: VarKind::SetDim); i < e; ++i) { |
1413 | Identifier id = space.getId(kind: VarKind::SetDim, pos: i); |
1414 | dims.push_back(Elt: id.hasValue() ? Value(id.getValue<Value>()) : Value()); |
1415 | } |
1416 | for (unsigned i = 0, e = getNumVarKind(kind: VarKind::Symbol); i < e; ++i) { |
1417 | Identifier id = space.getId(kind: VarKind::Symbol, pos: i); |
1418 | syms.push_back(Elt: id.hasValue() ? Value(id.getValue<Value>()) : Value()); |
1419 | } |
1420 | |
1421 | AffineMap alignedMap = |
1422 | alignAffineMapWithValues(map, operands, dims, syms, newSyms: newSymsPtr); |
1423 | // All symbols are already part of this FlatAffineValueConstraints. |
1424 | assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols"); |
1425 | assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) && |
1426 | "unexpected new/missing symbols"); |
1427 | return alignedMap; |
1428 | } |
1429 | |
1430 | bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos, |
1431 | unsigned offset) const { |
1432 | SmallVector<std::optional<Value>> maybeValues = getMaybeValues(); |
1433 | for (unsigned i = offset, e = maybeValues.size(); i < e; ++i) |
1434 | if (maybeValues[i] && maybeValues[i].value() == val) { |
1435 | *pos = i; |
1436 | return true; |
1437 | } |
1438 | return false; |
1439 | } |
1440 | |
1441 | bool FlatLinearValueConstraints::containsVar(Value val) const { |
1442 | unsigned pos; |
1443 | return findVar(val, pos: &pos, offset: 0); |
1444 | } |
1445 | |
1446 | void FlatLinearValueConstraints::addBound(BoundType type, Value val, |
1447 | int64_t value) { |
1448 | unsigned pos; |
1449 | if (!findVar(val, pos: &pos)) |
1450 | // This is a pre-condition for this method. |
1451 | assert(0 && "var not found"); |
1452 | addBound(type, pos, value); |
1453 | } |
1454 | |
1455 | void FlatLinearConstraints::printSpace(raw_ostream &os) const { |
1456 | IntegerPolyhedron::printSpace(os); |
1457 | os << "("; |
1458 | for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) |
1459 | os << "None\t"; |
1460 | for (unsigned i = getVarKindOffset(kind: VarKind::Local), |
1461 | e = getVarKindEnd(kind: VarKind::Local); |
1462 | i < e; ++i) |
1463 | os << "Local\t"; |
1464 | os << "const)\n"; |
1465 | } |
1466 | |
1467 | void FlatLinearValueConstraints::printSpace(raw_ostream &os) const { |
1468 | IntegerPolyhedron::printSpace(os); |
1469 | os << "("; |
1470 | for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) { |
1471 | if (hasValue(pos: i)) |
1472 | os << "Value\t"; |
1473 | else |
1474 | os << "None\t"; |
1475 | } |
1476 | for (unsigned i = getVarKindOffset(kind: VarKind::Local), |
1477 | e = getVarKindEnd(kind: VarKind::Local); |
1478 | i < e; ++i) |
1479 | os << "Local\t"; |
1480 | os << "const)\n"; |
1481 | } |
1482 | |
1483 | void FlatLinearValueConstraints::projectOut(Value val) { |
1484 | unsigned pos; |
1485 | bool ret = findVar(val, pos: &pos); |
1486 | assert(ret); |
1487 | (void)ret; |
1488 | fourierMotzkinEliminate(pos); |
1489 | } |
1490 | |
1491 | LogicalResult FlatLinearValueConstraints::unionBoundingBox( |
1492 | const FlatLinearValueConstraints &otherCst) { |
1493 | assert(otherCst.getNumDimVars() == getNumDimVars() && "dims mismatch"); |
1494 | SmallVector<std::optional<Value>> maybeValues = getMaybeValues(), |
1495 | otherMaybeValues = |
1496 | otherCst.getMaybeValues(); |
1497 | assert(std::equal(maybeValues.begin(), maybeValues.begin() + getNumDimVars(), |
1498 | otherMaybeValues.begin(), |
1499 | otherMaybeValues.begin() + getNumDimVars()) && |
1500 | "dim values mismatch"); |
1501 | assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here"); |
1502 | assert(getNumLocalVars() == 0 && "local vars not supported yet here"); |
1503 | |
1504 | // Align `other` to this. |
1505 | if (!areVarsAligned(a: *this, b: otherCst)) { |
1506 | FlatLinearValueConstraints otherCopy(otherCst); |
1507 | mergeAndAlignVars(/*offset=*/getNumDimVars(), a: this, b: &otherCopy); |
1508 | return IntegerPolyhedron::unionBoundingBox(other: otherCopy); |
1509 | } |
1510 | |
1511 | return IntegerPolyhedron::unionBoundingBox(other: otherCst); |
1512 | } |
1513 | |
1514 | //===----------------------------------------------------------------------===// |
1515 | // Helper functions |
1516 | //===----------------------------------------------------------------------===// |
1517 | |
1518 | AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands, |
1519 | ValueRange dims, ValueRange syms, |
1520 | SmallVector<Value> *newSyms) { |
1521 | assert(operands.size() == map.getNumInputs() && |
1522 | "expected same number of operands and map inputs"); |
1523 | MLIRContext *ctx = map.getContext(); |
1524 | Builder builder(ctx); |
1525 | SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {}); |
1526 | unsigned numSymbols = syms.size(); |
1527 | SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {}); |
1528 | if (newSyms) { |
1529 | newSyms->clear(); |
1530 | newSyms->append(in_start: syms.begin(), in_end: syms.end()); |
1531 | } |
1532 | |
1533 | for (const auto &operand : llvm::enumerate(First&: operands)) { |
1534 | // Compute replacement dim/sym of operand. |
1535 | AffineExpr replacement; |
1536 | auto dimIt = llvm::find(Range&: dims, Val: operand.value()); |
1537 | auto symIt = llvm::find(Range&: syms, Val: operand.value()); |
1538 | if (dimIt != dims.end()) { |
1539 | replacement = |
1540 | builder.getAffineDimExpr(position: std::distance(first: dims.begin(), last: dimIt)); |
1541 | } else if (symIt != syms.end()) { |
1542 | replacement = |
1543 | builder.getAffineSymbolExpr(position: std::distance(first: syms.begin(), last: symIt)); |
1544 | } else { |
1545 | // This operand is neither a dimension nor a symbol. Add it as a new |
1546 | // symbol. |
1547 | replacement = builder.getAffineSymbolExpr(position: numSymbols++); |
1548 | if (newSyms) |
1549 | newSyms->push_back(Elt: operand.value()); |
1550 | } |
1551 | // Add to corresponding replacements vector. |
1552 | if (operand.index() < map.getNumDims()) { |
1553 | dimReplacements[operand.index()] = replacement; |
1554 | } else { |
1555 | symReplacements[operand.index() - map.getNumDims()] = replacement; |
1556 | } |
1557 | } |
1558 | |
1559 | return map.replaceDimsAndSymbols(dimReplacements, symReplacements, |
1560 | numResultDims: dims.size(), numResultSyms: numSymbols); |
1561 | } |
1562 | |
1563 | LogicalResult |
1564 | mlir::getMultiAffineFunctionFromMap(AffineMap map, |
1565 | MultiAffineFunction &multiAff) { |
1566 | FlatLinearConstraints cst; |
1567 | std::vector<SmallVector<int64_t, 8>> flattenedExprs; |
1568 | LogicalResult result = getFlattenedAffineExprs(map, flattenedExprs: &flattenedExprs, localVarCst: &cst); |
1569 | |
1570 | if (result.failed()) |
1571 | return failure(); |
1572 | |
1573 | DivisionRepr divs = cst.getLocalReprs(); |
1574 | assert(divs.hasAllReprs() && |
1575 | "AffineMap cannot produce divs without local representation"); |
1576 | |
1577 | // TODO: We shouldn't have to do this conversion. |
1578 | Matrix<DynamicAPInt> mat(map.getNumResults(), |
1579 | map.getNumInputs() + divs.getNumDivs() + 1); |
1580 | for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i) |
1581 | for (unsigned j = 0, f = flattenedExprs[i].size(); j < f; ++j) |
1582 | mat(i, j) = flattenedExprs[i][j]; |
1583 | |
1584 | multiAff = MultiAffineFunction( |
1585 | PresburgerSpace::getRelationSpace(numDomain: map.getNumDims(), numRange: map.getNumResults(), |
1586 | numSymbols: map.getNumSymbols(), numLocals: divs.getNumDivs()), |
1587 | mat, divs); |
1588 | |
1589 | return success(); |
1590 | } |
1591 |
Definitions
- AffineExprFlattener
- AffineExprFlattener
- addLocalFloorDivId
- addLocalIdSemiAffine
- SemiAffineExprFlattener
- addLocalIdSemiAffine
- getFlattenedAffineExprs
- getFlattenedAffineExpr
- getFlattenedAffineExprs
- getFlattenedAffineExprs
- composeMatchingMap
- detectAsMod
- detectAsFloorDiv
- dumpRow
- dumpPretty
- getLowerAndUpperBound
- detectAsExpr
- computeUnknownVars
- getSliceBounds
- flattenAlignedMapAndMergeLocals
- addBound
- addBound
- computeLocalVars
- getAsExpr
- getConstantBoundOnDimSize
- getAsIntegerSet
- FlatLinearValueConstraints
- appendDimVar
- appendSymbolVar
- insertDimVar
- insertSymbolVar
- insertVar
- insertVar
- areVarsAligned
- areVarsAlignedWithOther
- areVarsUnique
- areVarsUnique
- areVarsUnique
- mergeAndAlignVars
- mergeAndAlignVarsWithOther
- mergeSymbolVars
- removeVarRange
- computeAlignedMap
- findVar
- containsVar
- addBound
- printSpace
- printSpace
- projectOut
- unionBoundingBox
- alignAffineMapWithValues
Improve your Profiling and Debugging skills
Find out more