1 | //===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/AffineMap.h" |
10 | #include "AffineMapDetail.h" |
11 | #include "mlir/IR/AffineExpr.h" |
12 | #include "mlir/IR/Builders.h" |
13 | #include "mlir/IR/BuiltinAttributes.h" |
14 | #include "mlir/IR/BuiltinTypes.h" |
15 | #include "mlir/Support/LogicalResult.h" |
16 | #include "mlir/Support/MathExtras.h" |
17 | #include "llvm/ADT/STLExtras.h" |
18 | #include "llvm/ADT/SmallBitVector.h" |
19 | #include "llvm/ADT/SmallSet.h" |
20 | #include "llvm/ADT/SmallVector.h" |
21 | #include "llvm/ADT/StringRef.h" |
22 | #include "llvm/Support/raw_ostream.h" |
23 | #include <iterator> |
24 | #include <numeric> |
25 | #include <optional> |
26 | #include <type_traits> |
27 | |
28 | using namespace mlir; |
29 | |
30 | namespace { |
31 | |
32 | // AffineExprConstantFolder evaluates an affine expression using constant |
33 | // operands passed in 'operandConsts'. Returns an IntegerAttr attribute |
34 | // representing the constant value of the affine expression evaluated on |
35 | // constant 'operandConsts', or nullptr if it can't be folded. |
36 | class AffineExprConstantFolder { |
37 | public: |
38 | AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts) |
39 | : numDims(numDims), operandConsts(operandConsts) {} |
40 | |
41 | /// Attempt to constant fold the specified affine expr, or return null on |
42 | /// failure. |
43 | IntegerAttr constantFold(AffineExpr expr) { |
44 | if (auto result = constantFoldImpl(expr)) |
45 | return IntegerAttr::get(IndexType::get(expr.getContext()), *result); |
46 | return nullptr; |
47 | } |
48 | |
49 | bool hasPoison() const { return hasPoison_; } |
50 | |
51 | private: |
52 | std::optional<int64_t> constantFoldImpl(AffineExpr expr) { |
53 | switch (expr.getKind()) { |
54 | case AffineExprKind::Add: |
55 | return constantFoldBinExpr( |
56 | expr, op: [](int64_t lhs, int64_t rhs) { return lhs + rhs; }); |
57 | case AffineExprKind::Mul: |
58 | return constantFoldBinExpr( |
59 | expr, op: [](int64_t lhs, int64_t rhs) { return lhs * rhs; }); |
60 | case AffineExprKind::Mod: |
61 | return constantFoldBinExpr( |
62 | expr, op: [this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> { |
63 | if (rhs < 1) { |
64 | hasPoison_ = true; |
65 | return std::nullopt; |
66 | } |
67 | return mod(lhs, rhs); |
68 | }); |
69 | case AffineExprKind::FloorDiv: |
70 | return constantFoldBinExpr( |
71 | expr, op: [this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> { |
72 | if (rhs == 0) { |
73 | hasPoison_ = true; |
74 | return std::nullopt; |
75 | } |
76 | return floorDiv(lhs, rhs); |
77 | }); |
78 | case AffineExprKind::CeilDiv: |
79 | return constantFoldBinExpr( |
80 | expr, op: [this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> { |
81 | if (rhs == 0) { |
82 | hasPoison_ = true; |
83 | return std::nullopt; |
84 | } |
85 | return ceilDiv(lhs, rhs); |
86 | }); |
87 | case AffineExprKind::Constant: |
88 | return cast<AffineConstantExpr>(Val&: expr).getValue(); |
89 | case AffineExprKind::DimId: |
90 | if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>( |
91 | operandConsts[cast<AffineDimExpr>(expr).getPosition()])) |
92 | return attr.getInt(); |
93 | return std::nullopt; |
94 | case AffineExprKind::SymbolId: |
95 | if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>( |
96 | operandConsts[numDims + |
97 | cast<AffineSymbolExpr>(expr).getPosition()])) |
98 | return attr.getInt(); |
99 | return std::nullopt; |
100 | } |
101 | llvm_unreachable("Unknown AffineExpr" ); |
102 | } |
103 | |
104 | // TODO: Change these to operate on APInts too. |
105 | std::optional<int64_t> constantFoldBinExpr( |
106 | AffineExpr expr, |
107 | llvm::function_ref<std::optional<int64_t>(int64_t, int64_t)> op) { |
108 | auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr); |
109 | if (auto lhs = constantFoldImpl(expr: binOpExpr.getLHS())) |
110 | if (auto rhs = constantFoldImpl(expr: binOpExpr.getRHS())) |
111 | return op(*lhs, *rhs); |
112 | return std::nullopt; |
113 | } |
114 | |
115 | // The number of dimension operands in AffineMap containing this expression. |
116 | unsigned numDims; |
117 | // The constant valued operands used to evaluate this AffineExpr. |
118 | ArrayRef<Attribute> operandConsts; |
119 | bool hasPoison_{false}; |
120 | }; |
121 | |
122 | } // namespace |
123 | |
124 | /// Returns a single constant result affine map. |
125 | AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) { |
126 | return get(/*dimCount=*/0, /*symbolCount=*/0, |
127 | result: {getAffineConstantExpr(constant: val, context)}); |
128 | } |
129 | |
130 | /// Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most |
131 | /// minor dimensions. |
132 | AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results, |
133 | MLIRContext *context) { |
134 | assert(dims >= results && "Dimension mismatch" ); |
135 | auto id = AffineMap::getMultiDimIdentityMap(numDims: dims, context); |
136 | return AffineMap::get(dimCount: dims, symbolCount: 0, results: id.getResults().take_back(N: results), context); |
137 | } |
138 | |
139 | AffineMap AffineMap::getFilteredIdentityMap( |
140 | MLIRContext *ctx, unsigned numDims, |
141 | llvm::function_ref<bool(AffineDimExpr)> keepDimFilter) { |
142 | auto identityMap = getMultiDimIdentityMap(numDims, context: ctx); |
143 | |
144 | // Apply filter to results. |
145 | llvm::SmallBitVector dropDimResults(numDims); |
146 | for (auto [idx, resultExpr] : llvm::enumerate(First: identityMap.getResults())) |
147 | dropDimResults[idx] = !keepDimFilter(cast<AffineDimExpr>(Val: resultExpr)); |
148 | |
149 | return identityMap.dropResults(positions: dropDimResults); |
150 | } |
151 | |
152 | bool AffineMap::isMinorIdentity() const { |
153 | return getNumDims() >= getNumResults() && |
154 | *this == |
155 | getMinorIdentityMap(dims: getNumDims(), results: getNumResults(), context: getContext()); |
156 | } |
157 | |
158 | /// Returns true if this affine map is a minor identity up to broadcasted |
159 | /// dimensions which are indicated by value 0 in the result. |
160 | bool AffineMap::isMinorIdentityWithBroadcasting( |
161 | SmallVectorImpl<unsigned> *broadcastedDims) const { |
162 | if (broadcastedDims) |
163 | broadcastedDims->clear(); |
164 | if (getNumDims() < getNumResults()) |
165 | return false; |
166 | unsigned suffixStart = getNumDims() - getNumResults(); |
167 | for (const auto &idxAndExpr : llvm::enumerate(First: getResults())) { |
168 | unsigned resIdx = idxAndExpr.index(); |
169 | AffineExpr expr = idxAndExpr.value(); |
170 | if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr)) { |
171 | // Each result may be either a constant 0 (broadcasted dimension). |
172 | if (constExpr.getValue() != 0) |
173 | return false; |
174 | if (broadcastedDims) |
175 | broadcastedDims->push_back(Elt: resIdx); |
176 | } else if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr)) { |
177 | // Or it may be the input dimension corresponding to this result position. |
178 | if (dimExpr.getPosition() != suffixStart + resIdx) |
179 | return false; |
180 | } else { |
181 | return false; |
182 | } |
183 | } |
184 | return true; |
185 | } |
186 | |
187 | /// Return true if this affine map can be converted to a minor identity with |
188 | /// broadcast by doing a permute. Return a permutation (there may be |
189 | /// several) to apply to get to a minor identity with broadcasts. |
190 | /// Ex: |
191 | /// * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with |
192 | /// perm = [1, 0] and broadcast d2 |
193 | /// * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by |
194 | /// permutation + broadcast |
195 | /// * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3) |
196 | /// with perm = [1, 0, 2] and broadcast d2 |
197 | /// * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra |
198 | /// leading broadcat dimensions. The map returned would be (0, 0, d0, d1) with |
199 | /// perm = [3, 0, 1, 2] |
200 | bool AffineMap::isPermutationOfMinorIdentityWithBroadcasting( |
201 | SmallVectorImpl<unsigned> &permutedDims) const { |
202 | unsigned projectionStart = |
203 | getNumResults() < getNumInputs() ? getNumInputs() - getNumResults() : 0; |
204 | permutedDims.clear(); |
205 | SmallVector<unsigned> broadcastDims; |
206 | permutedDims.resize(N: getNumResults(), NV: 0); |
207 | // If there are more results than input dimensions we want the new map to |
208 | // start with broadcast dimensions in order to be a minor identity with |
209 | // broadcasting. |
210 | unsigned leadingBroadcast = |
211 | getNumResults() > getNumInputs() ? getNumResults() - getNumInputs() : 0; |
212 | llvm::SmallBitVector dimFound(std::max(a: getNumInputs(), b: getNumResults()), |
213 | false); |
214 | for (const auto &idxAndExpr : llvm::enumerate(First: getResults())) { |
215 | unsigned resIdx = idxAndExpr.index(); |
216 | AffineExpr expr = idxAndExpr.value(); |
217 | // Each result may be either a constant 0 (broadcast dimension) or a |
218 | // dimension. |
219 | if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr)) { |
220 | if (constExpr.getValue() != 0) |
221 | return false; |
222 | broadcastDims.push_back(Elt: resIdx); |
223 | } else if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr)) { |
224 | if (dimExpr.getPosition() < projectionStart) |
225 | return false; |
226 | unsigned newPosition = |
227 | dimExpr.getPosition() - projectionStart + leadingBroadcast; |
228 | permutedDims[resIdx] = newPosition; |
229 | dimFound[newPosition] = true; |
230 | } else { |
231 | return false; |
232 | } |
233 | } |
234 | // Find a permuation for the broadcast dimension. Since they are broadcasted |
235 | // any valid permutation is acceptable. We just permute the dim into a slot |
236 | // without an existing dimension. |
237 | unsigned pos = 0; |
238 | for (auto dim : broadcastDims) { |
239 | while (pos < dimFound.size() && dimFound[pos]) { |
240 | pos++; |
241 | } |
242 | permutedDims[dim] = pos++; |
243 | } |
244 | return true; |
245 | } |
246 | |
247 | /// Returns an AffineMap representing a permutation. |
248 | AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation, |
249 | MLIRContext *context) { |
250 | assert(!permutation.empty() && |
251 | "Cannot create permutation map from empty permutation vector" ); |
252 | const auto *m = llvm::max_element(Range&: permutation); |
253 | auto permutationMap = getMultiDimMapWithTargets(numDims: *m + 1, targets: permutation, context); |
254 | assert(permutationMap.isPermutation() && "Invalid permutation vector" ); |
255 | return permutationMap; |
256 | } |
257 | AffineMap AffineMap::getPermutationMap(ArrayRef<int64_t> permutation, |
258 | MLIRContext *context) { |
259 | SmallVector<unsigned> perm = llvm::map_to_vector( |
260 | C&: permutation, F: [](int64_t i) { return static_cast<unsigned>(i); }); |
261 | return AffineMap::getPermutationMap(permutation: perm, context); |
262 | } |
263 | |
264 | AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims, |
265 | ArrayRef<unsigned> targets, |
266 | MLIRContext *context) { |
267 | SmallVector<AffineExpr, 4> affExprs; |
268 | for (unsigned t : targets) |
269 | affExprs.push_back(Elt: getAffineDimExpr(position: t, context)); |
270 | AffineMap result = AffineMap::get(/*dimCount=*/numDims, /*symbolCount=*/0, |
271 | results: affExprs, context); |
272 | return result; |
273 | } |
274 | |
275 | /// Creates an affine map each for each list of AffineExpr's in `exprsList` |
276 | /// while inferring the right number of dimensional and symbolic inputs needed |
277 | /// based on the maximum dimensional and symbolic identifier appearing in the |
278 | /// expressions. |
279 | template <typename AffineExprContainer> |
280 | static SmallVector<AffineMap, 4> |
281 | inferFromExprList(ArrayRef<AffineExprContainer> exprsList, |
282 | MLIRContext *context) { |
283 | if (exprsList.empty()) |
284 | return {}; |
285 | int64_t maxDim = -1, maxSym = -1; |
286 | getMaxDimAndSymbol(exprsList, maxDim, maxSym); |
287 | SmallVector<AffineMap, 4> maps; |
288 | maps.reserve(N: exprsList.size()); |
289 | for (const auto &exprs : exprsList) |
290 | maps.push_back(Elt: AffineMap::get(/*dimCount=*/maxDim + 1, |
291 | /*symbolCount=*/maxSym + 1, exprs, context)); |
292 | return maps; |
293 | } |
294 | |
295 | SmallVector<AffineMap, 4> |
296 | AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList, |
297 | MLIRContext *context) { |
298 | return ::inferFromExprList(exprsList, context); |
299 | } |
300 | |
301 | SmallVector<AffineMap, 4> |
302 | AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList, |
303 | MLIRContext *context) { |
304 | return ::inferFromExprList(exprsList, context); |
305 | } |
306 | |
307 | uint64_t AffineMap::getLargestKnownDivisorOfMapExprs() { |
308 | uint64_t gcd = 0; |
309 | for (AffineExpr resultExpr : getResults()) { |
310 | uint64_t thisGcd = resultExpr.getLargestKnownDivisor(); |
311 | gcd = std::gcd(m: gcd, n: thisGcd); |
312 | } |
313 | if (gcd == 0) |
314 | gcd = std::numeric_limits<uint64_t>::max(); |
315 | return gcd; |
316 | } |
317 | |
318 | AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims, |
319 | MLIRContext *context) { |
320 | SmallVector<AffineExpr, 4> dimExprs; |
321 | dimExprs.reserve(N: numDims); |
322 | for (unsigned i = 0; i < numDims; ++i) |
323 | dimExprs.push_back(Elt: mlir::getAffineDimExpr(position: i, context)); |
324 | return get(/*dimCount=*/numDims, /*symbolCount=*/0, results: dimExprs, context); |
325 | } |
326 | |
327 | MLIRContext *AffineMap::getContext() const { return map->context; } |
328 | |
329 | bool AffineMap::isIdentity() const { |
330 | if (getNumDims() != getNumResults()) |
331 | return false; |
332 | ArrayRef<AffineExpr> results = getResults(); |
333 | for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) { |
334 | auto expr = dyn_cast<AffineDimExpr>(Val: results[i]); |
335 | if (!expr || expr.getPosition() != i) |
336 | return false; |
337 | } |
338 | return true; |
339 | } |
340 | |
341 | bool AffineMap::isSymbolIdentity() const { |
342 | if (getNumSymbols() != getNumResults()) |
343 | return false; |
344 | ArrayRef<AffineExpr> results = getResults(); |
345 | for (unsigned i = 0, numSymbols = getNumSymbols(); i < numSymbols; ++i) { |
346 | auto expr = dyn_cast<AffineDimExpr>(Val: results[i]); |
347 | if (!expr || expr.getPosition() != i) |
348 | return false; |
349 | } |
350 | return true; |
351 | } |
352 | |
353 | bool AffineMap::isEmpty() const { |
354 | return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0; |
355 | } |
356 | |
357 | bool AffineMap::isSingleConstant() const { |
358 | return getNumResults() == 1 && isa<AffineConstantExpr>(Val: getResult(idx: 0)); |
359 | } |
360 | |
361 | bool AffineMap::isConstant() const { |
362 | return llvm::all_of(Range: getResults(), P: llvm::IsaPred<AffineConstantExpr>); |
363 | } |
364 | |
365 | int64_t AffineMap::getSingleConstantResult() const { |
366 | assert(isSingleConstant() && "map must have a single constant result" ); |
367 | return cast<AffineConstantExpr>(Val: getResult(idx: 0)).getValue(); |
368 | } |
369 | |
370 | SmallVector<int64_t> AffineMap::getConstantResults() const { |
371 | assert(isConstant() && "map must have only constant results" ); |
372 | SmallVector<int64_t> result; |
373 | for (auto expr : getResults()) |
374 | result.emplace_back(Args: cast<AffineConstantExpr>(Val&: expr).getValue()); |
375 | return result; |
376 | } |
377 | |
378 | unsigned AffineMap::getNumDims() const { |
379 | assert(map && "uninitialized map storage" ); |
380 | return map->numDims; |
381 | } |
382 | unsigned AffineMap::getNumSymbols() const { |
383 | assert(map && "uninitialized map storage" ); |
384 | return map->numSymbols; |
385 | } |
386 | unsigned AffineMap::getNumResults() const { return getResults().size(); } |
387 | unsigned AffineMap::getNumInputs() const { |
388 | assert(map && "uninitialized map storage" ); |
389 | return map->numDims + map->numSymbols; |
390 | } |
391 | ArrayRef<AffineExpr> AffineMap::getResults() const { |
392 | assert(map && "uninitialized map storage" ); |
393 | return map->results(); |
394 | } |
395 | AffineExpr AffineMap::getResult(unsigned idx) const { |
396 | return getResults()[idx]; |
397 | } |
398 | |
399 | unsigned AffineMap::getDimPosition(unsigned idx) const { |
400 | return cast<AffineDimExpr>(Val: getResult(idx)).getPosition(); |
401 | } |
402 | |
403 | std::optional<unsigned> AffineMap::getResultPosition(AffineExpr input) const { |
404 | if (!isa<AffineDimExpr>(Val: input)) |
405 | return std::nullopt; |
406 | |
407 | for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++) { |
408 | if (getResult(idx: i) == input) |
409 | return i; |
410 | } |
411 | |
412 | return std::nullopt; |
413 | } |
414 | |
415 | /// Folds the results of the application of an affine map on the provided |
416 | /// operands to a constant if possible. Returns false if the folding happens, |
417 | /// true otherwise. |
418 | LogicalResult AffineMap::constantFold(ArrayRef<Attribute> operandConstants, |
419 | SmallVectorImpl<Attribute> &results, |
420 | bool *hasPoison) const { |
421 | // Attempt partial folding. |
422 | SmallVector<int64_t, 2> integers; |
423 | partialConstantFold(operandConstants, results: &integers, hasPoison); |
424 | |
425 | // If all expressions folded to a constant, populate results with attributes |
426 | // containing those constants. |
427 | if (integers.empty()) |
428 | return failure(); |
429 | |
430 | auto range = llvm::map_range(C&: integers, F: [this](int64_t i) { |
431 | return IntegerAttr::get(IndexType::get(getContext()), i); |
432 | }); |
433 | results.append(in_start: range.begin(), in_end: range.end()); |
434 | return success(); |
435 | } |
436 | |
437 | AffineMap AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants, |
438 | SmallVectorImpl<int64_t> *results, |
439 | bool *hasPoison) const { |
440 | assert(getNumInputs() == operandConstants.size()); |
441 | |
442 | // Fold each of the result expressions. |
443 | AffineExprConstantFolder exprFolder(getNumDims(), operandConstants); |
444 | SmallVector<AffineExpr, 4> exprs; |
445 | exprs.reserve(N: getNumResults()); |
446 | |
447 | for (auto expr : getResults()) { |
448 | auto folded = exprFolder.constantFold(expr); |
449 | if (exprFolder.hasPoison() && hasPoison) { |
450 | *hasPoison = true; |
451 | return {}; |
452 | } |
453 | // If did not fold to a constant, keep the original expression, and clear |
454 | // the integer results vector. |
455 | if (folded) { |
456 | exprs.push_back( |
457 | Elt: getAffineConstantExpr(folded.getInt(), folded.getContext())); |
458 | if (results) |
459 | results->push_back(Elt: folded.getInt()); |
460 | } else { |
461 | exprs.push_back(Elt: expr); |
462 | if (results) { |
463 | results->clear(); |
464 | results = nullptr; |
465 | } |
466 | } |
467 | } |
468 | |
469 | return get(dimCount: getNumDims(), symbolCount: getNumSymbols(), results: exprs, context: getContext()); |
470 | } |
471 | |
472 | /// Walk all of the AffineExpr's in this mapping. Each node in an expression |
473 | /// tree is visited in postorder. |
474 | void AffineMap::walkExprs(llvm::function_ref<void(AffineExpr)> callback) const { |
475 | for (auto expr : getResults()) |
476 | expr.walk(callback); |
477 | } |
478 | |
479 | /// This method substitutes any uses of dimensions and symbols (e.g. |
480 | /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified |
481 | /// expression mapping. Because this can be used to eliminate dims and |
482 | /// symbols, the client needs to specify the number of dims and symbols in |
483 | /// the result. The returned map always has the same number of results. |
484 | AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements, |
485 | ArrayRef<AffineExpr> symReplacements, |
486 | unsigned numResultDims, |
487 | unsigned numResultSyms) const { |
488 | SmallVector<AffineExpr, 8> results; |
489 | results.reserve(N: getNumResults()); |
490 | for (auto expr : getResults()) |
491 | results.push_back( |
492 | Elt: expr.replaceDimsAndSymbols(dimReplacements, symReplacements)); |
493 | return get(dimCount: numResultDims, symbolCount: numResultSyms, results, context: getContext()); |
494 | } |
495 | |
496 | /// Sparse replace method. Apply AffineExpr::replace(`expr`, `replacement`) to |
497 | /// each of the results and return a new AffineMap with the new results and |
498 | /// with the specified number of dims and symbols. |
499 | AffineMap AffineMap::replace(AffineExpr expr, AffineExpr replacement, |
500 | unsigned numResultDims, |
501 | unsigned numResultSyms) const { |
502 | SmallVector<AffineExpr, 4> newResults; |
503 | newResults.reserve(N: getNumResults()); |
504 | for (AffineExpr e : getResults()) |
505 | newResults.push_back(Elt: e.replace(expr, replacement)); |
506 | return AffineMap::get(dimCount: numResultDims, symbolCount: numResultSyms, results: newResults, context: getContext()); |
507 | } |
508 | |
509 | /// Sparse replace method. Apply AffineExpr::replace(`map`) to each of the |
510 | /// results and return a new AffineMap with the new results and with the |
511 | /// specified number of dims and symbols. |
512 | AffineMap AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map, |
513 | unsigned numResultDims, |
514 | unsigned numResultSyms) const { |
515 | SmallVector<AffineExpr, 4> newResults; |
516 | newResults.reserve(N: getNumResults()); |
517 | for (AffineExpr e : getResults()) |
518 | newResults.push_back(Elt: e.replace(map)); |
519 | return AffineMap::get(dimCount: numResultDims, symbolCount: numResultSyms, results: newResults, context: getContext()); |
520 | } |
521 | |
522 | AffineMap |
523 | AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map) const { |
524 | SmallVector<AffineExpr, 4> newResults; |
525 | newResults.reserve(N: getNumResults()); |
526 | for (AffineExpr e : getResults()) |
527 | newResults.push_back(Elt: e.replace(map)); |
528 | return AffineMap::inferFromExprList(exprsList: newResults, context: getContext()).front(); |
529 | } |
530 | |
531 | AffineMap AffineMap::dropResults(const llvm::SmallBitVector &positions) const { |
532 | auto exprs = llvm::to_vector<4>(Range: getResults()); |
533 | // TODO: this is a pretty terrible API .. is there anything better? |
534 | for (auto pos = positions.find_last(); pos != -1; |
535 | pos = positions.find_prev(PriorTo: pos)) |
536 | exprs.erase(CI: exprs.begin() + pos); |
537 | return AffineMap::get(dimCount: getNumDims(), symbolCount: getNumSymbols(), results: exprs, context: getContext()); |
538 | } |
539 | |
540 | AffineMap AffineMap::compose(AffineMap map) const { |
541 | assert(getNumDims() == map.getNumResults() && "Number of results mismatch" ); |
542 | // Prepare `map` by concatenating the symbols and rewriting its exprs. |
543 | unsigned numDims = map.getNumDims(); |
544 | unsigned numSymbolsThisMap = getNumSymbols(); |
545 | unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols(); |
546 | SmallVector<AffineExpr, 8> newDims(numDims); |
547 | for (unsigned idx = 0; idx < numDims; ++idx) { |
548 | newDims[idx] = getAffineDimExpr(position: idx, context: getContext()); |
549 | } |
550 | SmallVector<AffineExpr, 8> newSymbols(numSymbols - numSymbolsThisMap); |
551 | for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) { |
552 | newSymbols[idx - numSymbolsThisMap] = |
553 | getAffineSymbolExpr(position: idx, context: getContext()); |
554 | } |
555 | auto newMap = |
556 | map.replaceDimsAndSymbols(dimReplacements: newDims, symReplacements: newSymbols, numResultDims: numDims, numResultSyms: numSymbols); |
557 | SmallVector<AffineExpr, 8> exprs; |
558 | exprs.reserve(N: getResults().size()); |
559 | for (auto expr : getResults()) |
560 | exprs.push_back(Elt: expr.compose(map: newMap)); |
561 | return AffineMap::get(dimCount: numDims, symbolCount: numSymbols, results: exprs, context: map.getContext()); |
562 | } |
563 | |
564 | SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const { |
565 | assert(getNumSymbols() == 0 && "Expected symbol-less map" ); |
566 | SmallVector<AffineExpr, 4> exprs; |
567 | exprs.reserve(N: values.size()); |
568 | MLIRContext *ctx = getContext(); |
569 | for (auto v : values) |
570 | exprs.push_back(Elt: getAffineConstantExpr(constant: v, context: ctx)); |
571 | auto resMap = compose(map: AffineMap::get(dimCount: 0, symbolCount: 0, results: exprs, context: ctx)); |
572 | SmallVector<int64_t, 4> res; |
573 | res.reserve(N: resMap.getNumResults()); |
574 | for (auto e : resMap.getResults()) |
575 | res.push_back(Elt: cast<AffineConstantExpr>(Val&: e).getValue()); |
576 | return res; |
577 | } |
578 | |
579 | bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const { |
580 | if (getNumSymbols() > 0) |
581 | return false; |
582 | |
583 | // Having more results than inputs means that results have duplicated dims or |
584 | // zeros that can't be mapped to input dims. |
585 | if (getNumResults() > getNumInputs()) |
586 | return false; |
587 | |
588 | SmallVector<bool, 8> seen(getNumInputs(), false); |
589 | // A projected permutation can have, at most, only one instance of each input |
590 | // dimension in the result expressions. Zeros are allowed as long as the |
591 | // number of result expressions is lower or equal than the number of input |
592 | // expressions. |
593 | for (auto expr : getResults()) { |
594 | if (auto dim = dyn_cast<AffineDimExpr>(Val&: expr)) { |
595 | if (seen[dim.getPosition()]) |
596 | return false; |
597 | seen[dim.getPosition()] = true; |
598 | } else { |
599 | auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr); |
600 | if (!allowZeroInResults || !constExpr || constExpr.getValue() != 0) |
601 | return false; |
602 | } |
603 | } |
604 | |
605 | // Results are either dims or zeros and zeros can be mapped to input dims. |
606 | return true; |
607 | } |
608 | |
609 | bool AffineMap::isPermutation() const { |
610 | if (getNumDims() != getNumResults()) |
611 | return false; |
612 | return isProjectedPermutation(); |
613 | } |
614 | |
615 | AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) const { |
616 | SmallVector<AffineExpr, 4> exprs; |
617 | exprs.reserve(N: resultPos.size()); |
618 | for (auto idx : resultPos) |
619 | exprs.push_back(Elt: getResult(idx)); |
620 | return AffineMap::get(dimCount: getNumDims(), symbolCount: getNumSymbols(), results: exprs, context: getContext()); |
621 | } |
622 | |
623 | AffineMap AffineMap::getSliceMap(unsigned start, unsigned length) const { |
624 | return AffineMap::get(dimCount: getNumDims(), symbolCount: getNumSymbols(), |
625 | results: getResults().slice(N: start, M: length), context: getContext()); |
626 | } |
627 | |
628 | AffineMap AffineMap::getMajorSubMap(unsigned numResults) const { |
629 | if (numResults == 0) |
630 | return AffineMap(); |
631 | if (numResults > getNumResults()) |
632 | return *this; |
633 | return getSliceMap(start: 0, length: numResults); |
634 | } |
635 | |
636 | AffineMap AffineMap::getMinorSubMap(unsigned numResults) const { |
637 | if (numResults == 0) |
638 | return AffineMap(); |
639 | if (numResults > getNumResults()) |
640 | return *this; |
641 | return getSliceMap(start: getNumResults() - numResults, length: numResults); |
642 | } |
643 | |
644 | /// Implementation detail to compress multiple affine maps with a compressionFun |
645 | /// that is expected to be either compressUnusedDims or compressUnusedSymbols. |
646 | /// The implementation keeps track of num dims and symbols across the different |
647 | /// affine maps. |
648 | static SmallVector<AffineMap> compressUnusedListImpl( |
649 | ArrayRef<AffineMap> maps, |
650 | llvm::function_ref<AffineMap(AffineMap)> compressionFun) { |
651 | if (maps.empty()) |
652 | return SmallVector<AffineMap>(); |
653 | SmallVector<AffineExpr> allExprs; |
654 | allExprs.reserve(N: maps.size() * maps.front().getNumResults()); |
655 | unsigned numDims = maps.front().getNumDims(), |
656 | numSymbols = maps.front().getNumSymbols(); |
657 | for (auto m : maps) { |
658 | assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() && |
659 | "expected maps with same num dims and symbols" ); |
660 | llvm::append_range(C&: allExprs, R: m.getResults()); |
661 | } |
662 | AffineMap unifiedMap = compressionFun( |
663 | AffineMap::get(dimCount: numDims, symbolCount: numSymbols, results: allExprs, context: maps.front().getContext())); |
664 | unsigned unifiedNumDims = unifiedMap.getNumDims(), |
665 | unifiedNumSymbols = unifiedMap.getNumSymbols(); |
666 | ArrayRef<AffineExpr> unifiedResults = unifiedMap.getResults(); |
667 | SmallVector<AffineMap> res; |
668 | res.reserve(N: maps.size()); |
669 | for (auto m : maps) { |
670 | res.push_back(Elt: AffineMap::get(dimCount: unifiedNumDims, symbolCount: unifiedNumSymbols, |
671 | results: unifiedResults.take_front(N: m.getNumResults()), |
672 | context: m.getContext())); |
673 | unifiedResults = unifiedResults.drop_front(N: m.getNumResults()); |
674 | } |
675 | return res; |
676 | } |
677 | |
678 | AffineMap mlir::compressDims(AffineMap map, |
679 | const llvm::SmallBitVector &unusedDims) { |
680 | return projectDims(map, projectedDimensions: unusedDims, /*compressDimsFlag=*/true); |
681 | } |
682 | |
683 | AffineMap mlir::compressUnusedDims(AffineMap map) { |
684 | return compressDims(map, unusedDims: getUnusedDimsBitVector(maps: {map})); |
685 | } |
686 | |
687 | SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) { |
688 | return compressUnusedListImpl( |
689 | maps, compressionFun: [](AffineMap m) { return compressUnusedDims(map: m); }); |
690 | } |
691 | |
692 | AffineMap mlir::compressSymbols(AffineMap map, |
693 | const llvm::SmallBitVector &unusedSymbols) { |
694 | return projectSymbols(map, projectedSymbols: unusedSymbols, /*compressSymbolsFlag=*/true); |
695 | } |
696 | |
697 | AffineMap mlir::compressUnusedSymbols(AffineMap map) { |
698 | return compressSymbols(map, unusedSymbols: getUnusedSymbolsBitVector(maps: {map})); |
699 | } |
700 | |
701 | SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) { |
702 | return compressUnusedListImpl( |
703 | maps, compressionFun: [](AffineMap m) { return compressUnusedSymbols(map: m); }); |
704 | } |
705 | |
706 | AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map, |
707 | ArrayRef<OpFoldResult> operands, |
708 | SmallVector<Value> &remainingValues) { |
709 | SmallVector<AffineExpr> dimReplacements, symReplacements; |
710 | int64_t numDims = 0; |
711 | for (int64_t i = 0; i < map.getNumDims(); ++i) { |
712 | if (auto attr = operands[i].dyn_cast<Attribute>()) { |
713 | dimReplacements.push_back( |
714 | Elt: b.getAffineConstantExpr(constant: cast<IntegerAttr>(attr).getInt())); |
715 | } else { |
716 | dimReplacements.push_back(Elt: b.getAffineDimExpr(position: numDims++)); |
717 | remainingValues.push_back(Elt: operands[i].get<Value>()); |
718 | } |
719 | } |
720 | int64_t numSymbols = 0; |
721 | for (int64_t i = 0; i < map.getNumSymbols(); ++i) { |
722 | if (auto attr = operands[i + map.getNumDims()].dyn_cast<Attribute>()) { |
723 | symReplacements.push_back( |
724 | Elt: b.getAffineConstantExpr(constant: cast<IntegerAttr>(attr).getInt())); |
725 | } else { |
726 | symReplacements.push_back(Elt: b.getAffineSymbolExpr(position: numSymbols++)); |
727 | remainingValues.push_back(Elt: operands[i + map.getNumDims()].get<Value>()); |
728 | } |
729 | } |
730 | return map.replaceDimsAndSymbols(dimReplacements, symReplacements, numResultDims: numDims, |
731 | numResultSyms: numSymbols); |
732 | } |
733 | |
734 | AffineMap mlir::simplifyAffineMap(AffineMap map) { |
735 | SmallVector<AffineExpr, 8> exprs; |
736 | for (auto e : map.getResults()) { |
737 | exprs.push_back( |
738 | Elt: simplifyAffineExpr(expr: e, numDims: map.getNumDims(), numSymbols: map.getNumSymbols())); |
739 | } |
740 | return AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: exprs, |
741 | context: map.getContext()); |
742 | } |
743 | |
744 | AffineMap mlir::removeDuplicateExprs(AffineMap map) { |
745 | auto results = map.getResults(); |
746 | SmallVector<AffineExpr, 4> uniqueExprs(results.begin(), results.end()); |
747 | uniqueExprs.erase(CS: std::unique(first: uniqueExprs.begin(), last: uniqueExprs.end()), |
748 | CE: uniqueExprs.end()); |
749 | return AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: uniqueExprs, |
750 | context: map.getContext()); |
751 | } |
752 | |
753 | AffineMap mlir::inversePermutation(AffineMap map) { |
754 | if (map.isEmpty()) |
755 | return map; |
756 | assert(map.getNumSymbols() == 0 && "expected map without symbols" ); |
757 | SmallVector<AffineExpr, 4> exprs(map.getNumDims()); |
758 | for (const auto &en : llvm::enumerate(First: map.getResults())) { |
759 | auto expr = en.value(); |
760 | // Skip non-permutations. |
761 | if (auto d = dyn_cast<AffineDimExpr>(Val&: expr)) { |
762 | if (exprs[d.getPosition()]) |
763 | continue; |
764 | exprs[d.getPosition()] = getAffineDimExpr(position: en.index(), context: d.getContext()); |
765 | } |
766 | } |
767 | SmallVector<AffineExpr, 4> seenExprs; |
768 | seenExprs.reserve(N: map.getNumDims()); |
769 | for (auto expr : exprs) |
770 | if (expr) |
771 | seenExprs.push_back(Elt: expr); |
772 | if (seenExprs.size() != map.getNumInputs()) |
773 | return AffineMap(); |
774 | return AffineMap::get(dimCount: map.getNumResults(), symbolCount: 0, results: seenExprs, context: map.getContext()); |
775 | } |
776 | |
777 | AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) { |
778 | assert(map.isProjectedPermutation(/*allowZeroInResults=*/true)); |
779 | MLIRContext *context = map.getContext(); |
780 | AffineExpr zero = mlir::getAffineConstantExpr(constant: 0, context); |
781 | // Start with all the results as 0. |
782 | SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero); |
783 | for (unsigned i : llvm::seq(Begin: unsigned(0), End: map.getNumResults())) { |
784 | // Skip zeros from input map. 'exprs' is already initialized to zero. |
785 | if (auto constExpr = dyn_cast<AffineConstantExpr>(Val: map.getResult(idx: i))) { |
786 | assert(constExpr.getValue() == 0 && |
787 | "Unexpected constant in projected permutation" ); |
788 | (void)constExpr; |
789 | continue; |
790 | } |
791 | |
792 | // Reverse each dimension existing in the original map result. |
793 | exprs[map.getDimPosition(idx: i)] = getAffineDimExpr(position: i, context); |
794 | } |
795 | return AffineMap::get(dimCount: map.getNumResults(), /*symbolCount=*/0, results: exprs, context); |
796 | } |
797 | |
798 | AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) { |
799 | unsigned numResults = 0, numDims = 0, numSymbols = 0; |
800 | for (auto m : maps) |
801 | numResults += m.getNumResults(); |
802 | SmallVector<AffineExpr, 8> results; |
803 | results.reserve(N: numResults); |
804 | for (auto m : maps) { |
805 | for (auto res : m.getResults()) |
806 | results.push_back(Elt: res.shiftSymbols(numSymbols: m.getNumSymbols(), shift: numSymbols)); |
807 | |
808 | numSymbols += m.getNumSymbols(); |
809 | numDims = std::max(a: m.getNumDims(), b: numDims); |
810 | } |
811 | return AffineMap::get(dimCount: numDims, symbolCount: numSymbols, results, |
812 | context: maps.front().getContext()); |
813 | } |
814 | |
815 | /// Common implementation to project out dimensions or symbols from an affine |
816 | /// map based on the template type. |
817 | /// Additionally, if 'compress' is true, the projected out dimensions or symbols |
818 | /// are also dropped from the resulting map. |
819 | template <typename AffineDimOrSymExpr> |
820 | static AffineMap projectCommonImpl(AffineMap map, |
821 | const llvm::SmallBitVector &toProject, |
822 | bool compress) { |
823 | static_assert(llvm::is_one_of<AffineDimOrSymExpr, AffineDimExpr, |
824 | AffineSymbolExpr>::value, |
825 | "expected AffineDimExpr or AffineSymbolExpr" ); |
826 | |
827 | constexpr bool isDim = std::is_same<AffineDimOrSymExpr, AffineDimExpr>::value; |
828 | int64_t numDimOrSym = (isDim) ? map.getNumDims() : map.getNumSymbols(); |
829 | SmallVector<AffineExpr> replacements; |
830 | replacements.reserve(N: numDimOrSym); |
831 | |
832 | auto createNewDimOrSym = (isDim) ? getAffineDimExpr : getAffineSymbolExpr; |
833 | |
834 | using replace_fn_ty = |
835 | std::function<AffineExpr(AffineExpr, ArrayRef<AffineExpr>)>; |
836 | replace_fn_ty replaceDims = [](AffineExpr e, |
837 | ArrayRef<AffineExpr> replacements) { |
838 | return e.replaceDims(dimReplacements: replacements); |
839 | }; |
840 | replace_fn_ty replaceSymbols = [](AffineExpr e, |
841 | ArrayRef<AffineExpr> replacements) { |
842 | return e.replaceSymbols(symReplacements: replacements); |
843 | }; |
844 | replace_fn_ty replaceNewDimOrSym = (isDim) ? replaceDims : replaceSymbols; |
845 | |
846 | MLIRContext *context = map.getContext(); |
847 | int64_t newNumDimOrSym = 0; |
848 | for (unsigned dimOrSym = 0; dimOrSym < numDimOrSym; ++dimOrSym) { |
849 | if (toProject.test(Idx: dimOrSym)) { |
850 | replacements.push_back(Elt: getAffineConstantExpr(constant: 0, context)); |
851 | continue; |
852 | } |
853 | int64_t newPos = compress ? newNumDimOrSym++ : dimOrSym; |
854 | replacements.push_back(Elt: createNewDimOrSym(newPos, context)); |
855 | } |
856 | SmallVector<AffineExpr> resultExprs; |
857 | resultExprs.reserve(N: map.getNumResults()); |
858 | for (auto e : map.getResults()) |
859 | resultExprs.push_back(Elt: replaceNewDimOrSym(e, replacements)); |
860 | |
861 | int64_t numDims = (compress && isDim) ? newNumDimOrSym : map.getNumDims(); |
862 | int64_t numSyms = (compress && !isDim) ? newNumDimOrSym : map.getNumSymbols(); |
863 | return AffineMap::get(dimCount: numDims, symbolCount: numSyms, results: resultExprs, context); |
864 | } |
865 | |
866 | AffineMap mlir::projectDims(AffineMap map, |
867 | const llvm::SmallBitVector &projectedDimensions, |
868 | bool compressDimsFlag) { |
869 | return projectCommonImpl<AffineDimExpr>(map, toProject: projectedDimensions, |
870 | compress: compressDimsFlag); |
871 | } |
872 | |
873 | AffineMap mlir::projectSymbols(AffineMap map, |
874 | const llvm::SmallBitVector &projectedSymbols, |
875 | bool compressSymbolsFlag) { |
876 | return projectCommonImpl<AffineSymbolExpr>(map, toProject: projectedSymbols, |
877 | compress: compressSymbolsFlag); |
878 | } |
879 | |
880 | AffineMap mlir::getProjectedMap(AffineMap map, |
881 | const llvm::SmallBitVector &projectedDimensions, |
882 | bool compressDimsFlag, |
883 | bool compressSymbolsFlag) { |
884 | map = projectDims(map, projectedDimensions, compressDimsFlag); |
885 | if (compressSymbolsFlag) |
886 | map = compressUnusedSymbols(map); |
887 | return map; |
888 | } |
889 | |
890 | llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) { |
891 | unsigned numDims = maps[0].getNumDims(); |
892 | llvm::SmallBitVector numDimsBitVector(numDims, true); |
893 | for (AffineMap m : maps) { |
894 | for (unsigned i = 0; i < numDims; ++i) { |
895 | if (m.isFunctionOfDim(position: i)) |
896 | numDimsBitVector.reset(Idx: i); |
897 | } |
898 | } |
899 | return numDimsBitVector; |
900 | } |
901 | |
902 | llvm::SmallBitVector mlir::getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps) { |
903 | unsigned numSymbols = maps[0].getNumSymbols(); |
904 | llvm::SmallBitVector numSymbolsBitVector(numSymbols, true); |
905 | for (AffineMap m : maps) { |
906 | for (unsigned i = 0; i < numSymbols; ++i) { |
907 | if (m.isFunctionOfSymbol(position: i)) |
908 | numSymbolsBitVector.reset(Idx: i); |
909 | } |
910 | } |
911 | return numSymbolsBitVector; |
912 | } |
913 | |
914 | AffineMap |
915 | mlir::expandDimsToRank(AffineMap map, int64_t rank, |
916 | const llvm::SmallBitVector &projectedDimensions) { |
917 | auto id = AffineMap::getMultiDimIdentityMap(numDims: rank, context: map.getContext()); |
918 | AffineMap proj = id.dropResults(positions: projectedDimensions); |
919 | return map.compose(map: proj); |
920 | } |
921 | |
922 | //===----------------------------------------------------------------------===// |
923 | // MutableAffineMap. |
924 | //===----------------------------------------------------------------------===// |
925 | |
926 | MutableAffineMap::MutableAffineMap(AffineMap map) |
927 | : results(map.getResults().begin(), map.getResults().end()), |
928 | numDims(map.getNumDims()), numSymbols(map.getNumSymbols()), |
929 | context(map.getContext()) {} |
930 | |
931 | void MutableAffineMap::reset(AffineMap map) { |
932 | results.clear(); |
933 | numDims = map.getNumDims(); |
934 | numSymbols = map.getNumSymbols(); |
935 | context = map.getContext(); |
936 | llvm::append_range(C&: results, R: map.getResults()); |
937 | } |
938 | |
939 | bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const { |
940 | return results[idx].isMultipleOf(factor); |
941 | } |
942 | |
943 | // Simplifies the result affine expressions of this map. The expressions |
944 | // have to be pure for the simplification implemented. |
945 | void MutableAffineMap::simplify() { |
946 | // Simplify each of the results if possible. |
947 | // TODO: functional-style map |
948 | for (unsigned i = 0, e = getNumResults(); i < e; i++) { |
949 | results[i] = simplifyAffineExpr(expr: getResult(idx: i), numDims, numSymbols); |
950 | } |
951 | } |
952 | |
953 | AffineMap MutableAffineMap::getAffineMap() const { |
954 | return AffineMap::get(dimCount: numDims, symbolCount: numSymbols, results, context); |
955 | } |
956 | |