1//===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/AffineMap.h"
10#include "AffineMapDetail.h"
11#include "mlir/IR/AffineExpr.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinAttributes.h"
14#include "mlir/IR/BuiltinTypes.h"
15#include "mlir/Support/LogicalResult.h"
16#include "mlir/Support/MathExtras.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/SmallBitVector.h"
19#include "llvm/ADT/SmallSet.h"
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/ADT/StringRef.h"
22#include "llvm/Support/raw_ostream.h"
23#include <iterator>
24#include <numeric>
25#include <optional>
26#include <type_traits>
27
28using namespace mlir;
29
30namespace {
31
32// AffineExprConstantFolder evaluates an affine expression using constant
33// operands passed in 'operandConsts'. Returns an IntegerAttr attribute
34// representing the constant value of the affine expression evaluated on
35// constant 'operandConsts', or nullptr if it can't be folded.
36class AffineExprConstantFolder {
37public:
38 AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts)
39 : numDims(numDims), operandConsts(operandConsts) {}
40
41 /// Attempt to constant fold the specified affine expr, or return null on
42 /// failure.
43 IntegerAttr constantFold(AffineExpr expr) {
44 if (auto result = constantFoldImpl(expr))
45 return IntegerAttr::get(IndexType::get(expr.getContext()), *result);
46 return nullptr;
47 }
48
49 bool hasPoison() const { return hasPoison_; }
50
51private:
52 std::optional<int64_t> constantFoldImpl(AffineExpr expr) {
53 switch (expr.getKind()) {
54 case AffineExprKind::Add:
55 return constantFoldBinExpr(
56 expr, op: [](int64_t lhs, int64_t rhs) { return lhs + rhs; });
57 case AffineExprKind::Mul:
58 return constantFoldBinExpr(
59 expr, op: [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
60 case AffineExprKind::Mod:
61 return constantFoldBinExpr(
62 expr, op: [this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
63 if (rhs < 1) {
64 hasPoison_ = true;
65 return std::nullopt;
66 }
67 return mod(lhs, rhs);
68 });
69 case AffineExprKind::FloorDiv:
70 return constantFoldBinExpr(
71 expr, op: [this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
72 if (rhs == 0) {
73 hasPoison_ = true;
74 return std::nullopt;
75 }
76 return floorDiv(lhs, rhs);
77 });
78 case AffineExprKind::CeilDiv:
79 return constantFoldBinExpr(
80 expr, op: [this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
81 if (rhs == 0) {
82 hasPoison_ = true;
83 return std::nullopt;
84 }
85 return ceilDiv(lhs, rhs);
86 });
87 case AffineExprKind::Constant:
88 return cast<AffineConstantExpr>(Val&: expr).getValue();
89 case AffineExprKind::DimId:
90 if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
91 operandConsts[cast<AffineDimExpr>(expr).getPosition()]))
92 return attr.getInt();
93 return std::nullopt;
94 case AffineExprKind::SymbolId:
95 if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
96 operandConsts[numDims +
97 cast<AffineSymbolExpr>(expr).getPosition()]))
98 return attr.getInt();
99 return std::nullopt;
100 }
101 llvm_unreachable("Unknown AffineExpr");
102 }
103
104 // TODO: Change these to operate on APInts too.
105 std::optional<int64_t> constantFoldBinExpr(
106 AffineExpr expr,
107 llvm::function_ref<std::optional<int64_t>(int64_t, int64_t)> op) {
108 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
109 if (auto lhs = constantFoldImpl(expr: binOpExpr.getLHS()))
110 if (auto rhs = constantFoldImpl(expr: binOpExpr.getRHS()))
111 return op(*lhs, *rhs);
112 return std::nullopt;
113 }
114
115 // The number of dimension operands in AffineMap containing this expression.
116 unsigned numDims;
117 // The constant valued operands used to evaluate this AffineExpr.
118 ArrayRef<Attribute> operandConsts;
119 bool hasPoison_{false};
120};
121
122} // namespace
123
124/// Returns a single constant result affine map.
125AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
126 return get(/*dimCount=*/0, /*symbolCount=*/0,
127 result: {getAffineConstantExpr(constant: val, context)});
128}
129
130/// Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most
131/// minor dimensions.
132AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results,
133 MLIRContext *context) {
134 assert(dims >= results && "Dimension mismatch");
135 auto id = AffineMap::getMultiDimIdentityMap(numDims: dims, context);
136 return AffineMap::get(dimCount: dims, symbolCount: 0, results: id.getResults().take_back(N: results), context);
137}
138
139AffineMap AffineMap::getFilteredIdentityMap(
140 MLIRContext *ctx, unsigned numDims,
141 llvm::function_ref<bool(AffineDimExpr)> keepDimFilter) {
142 auto identityMap = getMultiDimIdentityMap(numDims, context: ctx);
143
144 // Apply filter to results.
145 llvm::SmallBitVector dropDimResults(numDims);
146 for (auto [idx, resultExpr] : llvm::enumerate(First: identityMap.getResults()))
147 dropDimResults[idx] = !keepDimFilter(cast<AffineDimExpr>(Val: resultExpr));
148
149 return identityMap.dropResults(positions: dropDimResults);
150}
151
152bool AffineMap::isMinorIdentity() const {
153 return getNumDims() >= getNumResults() &&
154 *this ==
155 getMinorIdentityMap(dims: getNumDims(), results: getNumResults(), context: getContext());
156}
157
158/// Returns true if this affine map is a minor identity up to broadcasted
159/// dimensions which are indicated by value 0 in the result.
160bool AffineMap::isMinorIdentityWithBroadcasting(
161 SmallVectorImpl<unsigned> *broadcastedDims) const {
162 if (broadcastedDims)
163 broadcastedDims->clear();
164 if (getNumDims() < getNumResults())
165 return false;
166 unsigned suffixStart = getNumDims() - getNumResults();
167 for (const auto &idxAndExpr : llvm::enumerate(First: getResults())) {
168 unsigned resIdx = idxAndExpr.index();
169 AffineExpr expr = idxAndExpr.value();
170 if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr)) {
171 // Each result may be either a constant 0 (broadcasted dimension).
172 if (constExpr.getValue() != 0)
173 return false;
174 if (broadcastedDims)
175 broadcastedDims->push_back(Elt: resIdx);
176 } else if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr)) {
177 // Or it may be the input dimension corresponding to this result position.
178 if (dimExpr.getPosition() != suffixStart + resIdx)
179 return false;
180 } else {
181 return false;
182 }
183 }
184 return true;
185}
186
187/// Return true if this affine map can be converted to a minor identity with
188/// broadcast by doing a permute. Return a permutation (there may be
189/// several) to apply to get to a minor identity with broadcasts.
190/// Ex:
191/// * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with
192/// perm = [1, 0] and broadcast d2
193/// * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by
194/// permutation + broadcast
195/// * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3)
196/// with perm = [1, 0, 2] and broadcast d2
197/// * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra
198/// leading broadcat dimensions. The map returned would be (0, 0, d0, d1) with
199/// perm = [3, 0, 1, 2]
200bool AffineMap::isPermutationOfMinorIdentityWithBroadcasting(
201 SmallVectorImpl<unsigned> &permutedDims) const {
202 unsigned projectionStart =
203 getNumResults() < getNumInputs() ? getNumInputs() - getNumResults() : 0;
204 permutedDims.clear();
205 SmallVector<unsigned> broadcastDims;
206 permutedDims.resize(N: getNumResults(), NV: 0);
207 // If there are more results than input dimensions we want the new map to
208 // start with broadcast dimensions in order to be a minor identity with
209 // broadcasting.
210 unsigned leadingBroadcast =
211 getNumResults() > getNumInputs() ? getNumResults() - getNumInputs() : 0;
212 llvm::SmallBitVector dimFound(std::max(a: getNumInputs(), b: getNumResults()),
213 false);
214 for (const auto &idxAndExpr : llvm::enumerate(First: getResults())) {
215 unsigned resIdx = idxAndExpr.index();
216 AffineExpr expr = idxAndExpr.value();
217 // Each result may be either a constant 0 (broadcast dimension) or a
218 // dimension.
219 if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr)) {
220 if (constExpr.getValue() != 0)
221 return false;
222 broadcastDims.push_back(Elt: resIdx);
223 } else if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr)) {
224 if (dimExpr.getPosition() < projectionStart)
225 return false;
226 unsigned newPosition =
227 dimExpr.getPosition() - projectionStart + leadingBroadcast;
228 permutedDims[resIdx] = newPosition;
229 dimFound[newPosition] = true;
230 } else {
231 return false;
232 }
233 }
234 // Find a permuation for the broadcast dimension. Since they are broadcasted
235 // any valid permutation is acceptable. We just permute the dim into a slot
236 // without an existing dimension.
237 unsigned pos = 0;
238 for (auto dim : broadcastDims) {
239 while (pos < dimFound.size() && dimFound[pos]) {
240 pos++;
241 }
242 permutedDims[dim] = pos++;
243 }
244 return true;
245}
246
247/// Returns an AffineMap representing a permutation.
248AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
249 MLIRContext *context) {
250 assert(!permutation.empty() &&
251 "Cannot create permutation map from empty permutation vector");
252 const auto *m = llvm::max_element(Range&: permutation);
253 auto permutationMap = getMultiDimMapWithTargets(numDims: *m + 1, targets: permutation, context);
254 assert(permutationMap.isPermutation() && "Invalid permutation vector");
255 return permutationMap;
256}
257AffineMap AffineMap::getPermutationMap(ArrayRef<int64_t> permutation,
258 MLIRContext *context) {
259 SmallVector<unsigned> perm = llvm::map_to_vector(
260 C&: permutation, F: [](int64_t i) { return static_cast<unsigned>(i); });
261 return AffineMap::getPermutationMap(permutation: perm, context);
262}
263
264AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims,
265 ArrayRef<unsigned> targets,
266 MLIRContext *context) {
267 SmallVector<AffineExpr, 4> affExprs;
268 for (unsigned t : targets)
269 affExprs.push_back(Elt: getAffineDimExpr(position: t, context));
270 AffineMap result = AffineMap::get(/*dimCount=*/numDims, /*symbolCount=*/0,
271 results: affExprs, context);
272 return result;
273}
274
275/// Creates an affine map each for each list of AffineExpr's in `exprsList`
276/// while inferring the right number of dimensional and symbolic inputs needed
277/// based on the maximum dimensional and symbolic identifier appearing in the
278/// expressions.
279template <typename AffineExprContainer>
280static SmallVector<AffineMap, 4>
281inferFromExprList(ArrayRef<AffineExprContainer> exprsList,
282 MLIRContext *context) {
283 if (exprsList.empty())
284 return {};
285 int64_t maxDim = -1, maxSym = -1;
286 getMaxDimAndSymbol(exprsList, maxDim, maxSym);
287 SmallVector<AffineMap, 4> maps;
288 maps.reserve(N: exprsList.size());
289 for (const auto &exprs : exprsList)
290 maps.push_back(Elt: AffineMap::get(/*dimCount=*/maxDim + 1,
291 /*symbolCount=*/maxSym + 1, exprs, context));
292 return maps;
293}
294
295SmallVector<AffineMap, 4>
296AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList,
297 MLIRContext *context) {
298 return ::inferFromExprList(exprsList, context);
299}
300
301SmallVector<AffineMap, 4>
302AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList,
303 MLIRContext *context) {
304 return ::inferFromExprList(exprsList, context);
305}
306
307uint64_t AffineMap::getLargestKnownDivisorOfMapExprs() {
308 uint64_t gcd = 0;
309 for (AffineExpr resultExpr : getResults()) {
310 uint64_t thisGcd = resultExpr.getLargestKnownDivisor();
311 gcd = std::gcd(m: gcd, n: thisGcd);
312 }
313 if (gcd == 0)
314 gcd = std::numeric_limits<uint64_t>::max();
315 return gcd;
316}
317
318AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
319 MLIRContext *context) {
320 SmallVector<AffineExpr, 4> dimExprs;
321 dimExprs.reserve(N: numDims);
322 for (unsigned i = 0; i < numDims; ++i)
323 dimExprs.push_back(Elt: mlir::getAffineDimExpr(position: i, context));
324 return get(/*dimCount=*/numDims, /*symbolCount=*/0, results: dimExprs, context);
325}
326
327MLIRContext *AffineMap::getContext() const { return map->context; }
328
329bool AffineMap::isIdentity() const {
330 if (getNumDims() != getNumResults())
331 return false;
332 ArrayRef<AffineExpr> results = getResults();
333 for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
334 auto expr = dyn_cast<AffineDimExpr>(Val: results[i]);
335 if (!expr || expr.getPosition() != i)
336 return false;
337 }
338 return true;
339}
340
341bool AffineMap::isSymbolIdentity() const {
342 if (getNumSymbols() != getNumResults())
343 return false;
344 ArrayRef<AffineExpr> results = getResults();
345 for (unsigned i = 0, numSymbols = getNumSymbols(); i < numSymbols; ++i) {
346 auto expr = dyn_cast<AffineDimExpr>(Val: results[i]);
347 if (!expr || expr.getPosition() != i)
348 return false;
349 }
350 return true;
351}
352
353bool AffineMap::isEmpty() const {
354 return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0;
355}
356
357bool AffineMap::isSingleConstant() const {
358 return getNumResults() == 1 && isa<AffineConstantExpr>(Val: getResult(idx: 0));
359}
360
361bool AffineMap::isConstant() const {
362 return llvm::all_of(Range: getResults(), P: llvm::IsaPred<AffineConstantExpr>);
363}
364
365int64_t AffineMap::getSingleConstantResult() const {
366 assert(isSingleConstant() && "map must have a single constant result");
367 return cast<AffineConstantExpr>(Val: getResult(idx: 0)).getValue();
368}
369
370SmallVector<int64_t> AffineMap::getConstantResults() const {
371 assert(isConstant() && "map must have only constant results");
372 SmallVector<int64_t> result;
373 for (auto expr : getResults())
374 result.emplace_back(Args: cast<AffineConstantExpr>(Val&: expr).getValue());
375 return result;
376}
377
378unsigned AffineMap::getNumDims() const {
379 assert(map && "uninitialized map storage");
380 return map->numDims;
381}
382unsigned AffineMap::getNumSymbols() const {
383 assert(map && "uninitialized map storage");
384 return map->numSymbols;
385}
386unsigned AffineMap::getNumResults() const { return getResults().size(); }
387unsigned AffineMap::getNumInputs() const {
388 assert(map && "uninitialized map storage");
389 return map->numDims + map->numSymbols;
390}
391ArrayRef<AffineExpr> AffineMap::getResults() const {
392 assert(map && "uninitialized map storage");
393 return map->results();
394}
395AffineExpr AffineMap::getResult(unsigned idx) const {
396 return getResults()[idx];
397}
398
399unsigned AffineMap::getDimPosition(unsigned idx) const {
400 return cast<AffineDimExpr>(Val: getResult(idx)).getPosition();
401}
402
403std::optional<unsigned> AffineMap::getResultPosition(AffineExpr input) const {
404 if (!isa<AffineDimExpr>(Val: input))
405 return std::nullopt;
406
407 for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++) {
408 if (getResult(idx: i) == input)
409 return i;
410 }
411
412 return std::nullopt;
413}
414
415/// Folds the results of the application of an affine map on the provided
416/// operands to a constant if possible. Returns false if the folding happens,
417/// true otherwise.
418LogicalResult AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
419 SmallVectorImpl<Attribute> &results,
420 bool *hasPoison) const {
421 // Attempt partial folding.
422 SmallVector<int64_t, 2> integers;
423 partialConstantFold(operandConstants, results: &integers, hasPoison);
424
425 // If all expressions folded to a constant, populate results with attributes
426 // containing those constants.
427 if (integers.empty())
428 return failure();
429
430 auto range = llvm::map_range(C&: integers, F: [this](int64_t i) {
431 return IntegerAttr::get(IndexType::get(getContext()), i);
432 });
433 results.append(in_start: range.begin(), in_end: range.end());
434 return success();
435}
436
437AffineMap AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
438 SmallVectorImpl<int64_t> *results,
439 bool *hasPoison) const {
440 assert(getNumInputs() == operandConstants.size());
441
442 // Fold each of the result expressions.
443 AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
444 SmallVector<AffineExpr, 4> exprs;
445 exprs.reserve(N: getNumResults());
446
447 for (auto expr : getResults()) {
448 auto folded = exprFolder.constantFold(expr);
449 if (exprFolder.hasPoison() && hasPoison) {
450 *hasPoison = true;
451 return {};
452 }
453 // If did not fold to a constant, keep the original expression, and clear
454 // the integer results vector.
455 if (folded) {
456 exprs.push_back(
457 Elt: getAffineConstantExpr(folded.getInt(), folded.getContext()));
458 if (results)
459 results->push_back(Elt: folded.getInt());
460 } else {
461 exprs.push_back(Elt: expr);
462 if (results) {
463 results->clear();
464 results = nullptr;
465 }
466 }
467 }
468
469 return get(dimCount: getNumDims(), symbolCount: getNumSymbols(), results: exprs, context: getContext());
470}
471
472/// Walk all of the AffineExpr's in this mapping. Each node in an expression
473/// tree is visited in postorder.
474void AffineMap::walkExprs(llvm::function_ref<void(AffineExpr)> callback) const {
475 for (auto expr : getResults())
476 expr.walk(callback);
477}
478
479/// This method substitutes any uses of dimensions and symbols (e.g.
480/// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
481/// expression mapping. Because this can be used to eliminate dims and
482/// symbols, the client needs to specify the number of dims and symbols in
483/// the result. The returned map always has the same number of results.
484AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
485 ArrayRef<AffineExpr> symReplacements,
486 unsigned numResultDims,
487 unsigned numResultSyms) const {
488 SmallVector<AffineExpr, 8> results;
489 results.reserve(N: getNumResults());
490 for (auto expr : getResults())
491 results.push_back(
492 Elt: expr.replaceDimsAndSymbols(dimReplacements, symReplacements));
493 return get(dimCount: numResultDims, symbolCount: numResultSyms, results, context: getContext());
494}
495
496/// Sparse replace method. Apply AffineExpr::replace(`expr`, `replacement`) to
497/// each of the results and return a new AffineMap with the new results and
498/// with the specified number of dims and symbols.
499AffineMap AffineMap::replace(AffineExpr expr, AffineExpr replacement,
500 unsigned numResultDims,
501 unsigned numResultSyms) const {
502 SmallVector<AffineExpr, 4> newResults;
503 newResults.reserve(N: getNumResults());
504 for (AffineExpr e : getResults())
505 newResults.push_back(Elt: e.replace(expr, replacement));
506 return AffineMap::get(dimCount: numResultDims, symbolCount: numResultSyms, results: newResults, context: getContext());
507}
508
509/// Sparse replace method. Apply AffineExpr::replace(`map`) to each of the
510/// results and return a new AffineMap with the new results and with the
511/// specified number of dims and symbols.
512AffineMap AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map,
513 unsigned numResultDims,
514 unsigned numResultSyms) const {
515 SmallVector<AffineExpr, 4> newResults;
516 newResults.reserve(N: getNumResults());
517 for (AffineExpr e : getResults())
518 newResults.push_back(Elt: e.replace(map));
519 return AffineMap::get(dimCount: numResultDims, symbolCount: numResultSyms, results: newResults, context: getContext());
520}
521
522AffineMap
523AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
524 SmallVector<AffineExpr, 4> newResults;
525 newResults.reserve(N: getNumResults());
526 for (AffineExpr e : getResults())
527 newResults.push_back(Elt: e.replace(map));
528 return AffineMap::inferFromExprList(exprsList: newResults, context: getContext()).front();
529}
530
531AffineMap AffineMap::dropResults(const llvm::SmallBitVector &positions) const {
532 auto exprs = llvm::to_vector<4>(Range: getResults());
533 // TODO: this is a pretty terrible API .. is there anything better?
534 for (auto pos = positions.find_last(); pos != -1;
535 pos = positions.find_prev(PriorTo: pos))
536 exprs.erase(CI: exprs.begin() + pos);
537 return AffineMap::get(dimCount: getNumDims(), symbolCount: getNumSymbols(), results: exprs, context: getContext());
538}
539
540AffineMap AffineMap::compose(AffineMap map) const {
541 assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
542 // Prepare `map` by concatenating the symbols and rewriting its exprs.
543 unsigned numDims = map.getNumDims();
544 unsigned numSymbolsThisMap = getNumSymbols();
545 unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols();
546 SmallVector<AffineExpr, 8> newDims(numDims);
547 for (unsigned idx = 0; idx < numDims; ++idx) {
548 newDims[idx] = getAffineDimExpr(position: idx, context: getContext());
549 }
550 SmallVector<AffineExpr, 8> newSymbols(numSymbols - numSymbolsThisMap);
551 for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) {
552 newSymbols[idx - numSymbolsThisMap] =
553 getAffineSymbolExpr(position: idx, context: getContext());
554 }
555 auto newMap =
556 map.replaceDimsAndSymbols(dimReplacements: newDims, symReplacements: newSymbols, numResultDims: numDims, numResultSyms: numSymbols);
557 SmallVector<AffineExpr, 8> exprs;
558 exprs.reserve(N: getResults().size());
559 for (auto expr : getResults())
560 exprs.push_back(Elt: expr.compose(map: newMap));
561 return AffineMap::get(dimCount: numDims, symbolCount: numSymbols, results: exprs, context: map.getContext());
562}
563
564SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const {
565 assert(getNumSymbols() == 0 && "Expected symbol-less map");
566 SmallVector<AffineExpr, 4> exprs;
567 exprs.reserve(N: values.size());
568 MLIRContext *ctx = getContext();
569 for (auto v : values)
570 exprs.push_back(Elt: getAffineConstantExpr(constant: v, context: ctx));
571 auto resMap = compose(map: AffineMap::get(dimCount: 0, symbolCount: 0, results: exprs, context: ctx));
572 SmallVector<int64_t, 4> res;
573 res.reserve(N: resMap.getNumResults());
574 for (auto e : resMap.getResults())
575 res.push_back(Elt: cast<AffineConstantExpr>(Val&: e).getValue());
576 return res;
577}
578
579bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const {
580 if (getNumSymbols() > 0)
581 return false;
582
583 // Having more results than inputs means that results have duplicated dims or
584 // zeros that can't be mapped to input dims.
585 if (getNumResults() > getNumInputs())
586 return false;
587
588 SmallVector<bool, 8> seen(getNumInputs(), false);
589 // A projected permutation can have, at most, only one instance of each input
590 // dimension in the result expressions. Zeros are allowed as long as the
591 // number of result expressions is lower or equal than the number of input
592 // expressions.
593 for (auto expr : getResults()) {
594 if (auto dim = dyn_cast<AffineDimExpr>(Val&: expr)) {
595 if (seen[dim.getPosition()])
596 return false;
597 seen[dim.getPosition()] = true;
598 } else {
599 auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr);
600 if (!allowZeroInResults || !constExpr || constExpr.getValue() != 0)
601 return false;
602 }
603 }
604
605 // Results are either dims or zeros and zeros can be mapped to input dims.
606 return true;
607}
608
609bool AffineMap::isPermutation() const {
610 if (getNumDims() != getNumResults())
611 return false;
612 return isProjectedPermutation();
613}
614
615AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) const {
616 SmallVector<AffineExpr, 4> exprs;
617 exprs.reserve(N: resultPos.size());
618 for (auto idx : resultPos)
619 exprs.push_back(Elt: getResult(idx));
620 return AffineMap::get(dimCount: getNumDims(), symbolCount: getNumSymbols(), results: exprs, context: getContext());
621}
622
623AffineMap AffineMap::getSliceMap(unsigned start, unsigned length) const {
624 return AffineMap::get(dimCount: getNumDims(), symbolCount: getNumSymbols(),
625 results: getResults().slice(N: start, M: length), context: getContext());
626}
627
628AffineMap AffineMap::getMajorSubMap(unsigned numResults) const {
629 if (numResults == 0)
630 return AffineMap();
631 if (numResults > getNumResults())
632 return *this;
633 return getSliceMap(start: 0, length: numResults);
634}
635
636AffineMap AffineMap::getMinorSubMap(unsigned numResults) const {
637 if (numResults == 0)
638 return AffineMap();
639 if (numResults > getNumResults())
640 return *this;
641 return getSliceMap(start: getNumResults() - numResults, length: numResults);
642}
643
644/// Implementation detail to compress multiple affine maps with a compressionFun
645/// that is expected to be either compressUnusedDims or compressUnusedSymbols.
646/// The implementation keeps track of num dims and symbols across the different
647/// affine maps.
648static SmallVector<AffineMap> compressUnusedListImpl(
649 ArrayRef<AffineMap> maps,
650 llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
651 if (maps.empty())
652 return SmallVector<AffineMap>();
653 SmallVector<AffineExpr> allExprs;
654 allExprs.reserve(N: maps.size() * maps.front().getNumResults());
655 unsigned numDims = maps.front().getNumDims(),
656 numSymbols = maps.front().getNumSymbols();
657 for (auto m : maps) {
658 assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() &&
659 "expected maps with same num dims and symbols");
660 llvm::append_range(C&: allExprs, R: m.getResults());
661 }
662 AffineMap unifiedMap = compressionFun(
663 AffineMap::get(dimCount: numDims, symbolCount: numSymbols, results: allExprs, context: maps.front().getContext()));
664 unsigned unifiedNumDims = unifiedMap.getNumDims(),
665 unifiedNumSymbols = unifiedMap.getNumSymbols();
666 ArrayRef<AffineExpr> unifiedResults = unifiedMap.getResults();
667 SmallVector<AffineMap> res;
668 res.reserve(N: maps.size());
669 for (auto m : maps) {
670 res.push_back(Elt: AffineMap::get(dimCount: unifiedNumDims, symbolCount: unifiedNumSymbols,
671 results: unifiedResults.take_front(N: m.getNumResults()),
672 context: m.getContext()));
673 unifiedResults = unifiedResults.drop_front(N: m.getNumResults());
674 }
675 return res;
676}
677
678AffineMap mlir::compressDims(AffineMap map,
679 const llvm::SmallBitVector &unusedDims) {
680 return projectDims(map, projectedDimensions: unusedDims, /*compressDimsFlag=*/true);
681}
682
683AffineMap mlir::compressUnusedDims(AffineMap map) {
684 return compressDims(map, unusedDims: getUnusedDimsBitVector(maps: {map}));
685}
686
687SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) {
688 return compressUnusedListImpl(
689 maps, compressionFun: [](AffineMap m) { return compressUnusedDims(map: m); });
690}
691
692AffineMap mlir::compressSymbols(AffineMap map,
693 const llvm::SmallBitVector &unusedSymbols) {
694 return projectSymbols(map, projectedSymbols: unusedSymbols, /*compressSymbolsFlag=*/true);
695}
696
697AffineMap mlir::compressUnusedSymbols(AffineMap map) {
698 return compressSymbols(map, unusedSymbols: getUnusedSymbolsBitVector(maps: {map}));
699}
700
701SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) {
702 return compressUnusedListImpl(
703 maps, compressionFun: [](AffineMap m) { return compressUnusedSymbols(map: m); });
704}
705
706AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map,
707 ArrayRef<OpFoldResult> operands,
708 SmallVector<Value> &remainingValues) {
709 SmallVector<AffineExpr> dimReplacements, symReplacements;
710 int64_t numDims = 0;
711 for (int64_t i = 0; i < map.getNumDims(); ++i) {
712 if (auto attr = operands[i].dyn_cast<Attribute>()) {
713 dimReplacements.push_back(
714 Elt: b.getAffineConstantExpr(constant: cast<IntegerAttr>(attr).getInt()));
715 } else {
716 dimReplacements.push_back(Elt: b.getAffineDimExpr(position: numDims++));
717 remainingValues.push_back(Elt: operands[i].get<Value>());
718 }
719 }
720 int64_t numSymbols = 0;
721 for (int64_t i = 0; i < map.getNumSymbols(); ++i) {
722 if (auto attr = operands[i + map.getNumDims()].dyn_cast<Attribute>()) {
723 symReplacements.push_back(
724 Elt: b.getAffineConstantExpr(constant: cast<IntegerAttr>(attr).getInt()));
725 } else {
726 symReplacements.push_back(Elt: b.getAffineSymbolExpr(position: numSymbols++));
727 remainingValues.push_back(Elt: operands[i + map.getNumDims()].get<Value>());
728 }
729 }
730 return map.replaceDimsAndSymbols(dimReplacements, symReplacements, numResultDims: numDims,
731 numResultSyms: numSymbols);
732}
733
734AffineMap mlir::simplifyAffineMap(AffineMap map) {
735 SmallVector<AffineExpr, 8> exprs;
736 for (auto e : map.getResults()) {
737 exprs.push_back(
738 Elt: simplifyAffineExpr(expr: e, numDims: map.getNumDims(), numSymbols: map.getNumSymbols()));
739 }
740 return AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: exprs,
741 context: map.getContext());
742}
743
744AffineMap mlir::removeDuplicateExprs(AffineMap map) {
745 auto results = map.getResults();
746 SmallVector<AffineExpr, 4> uniqueExprs(results.begin(), results.end());
747 uniqueExprs.erase(CS: std::unique(first: uniqueExprs.begin(), last: uniqueExprs.end()),
748 CE: uniqueExprs.end());
749 return AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: uniqueExprs,
750 context: map.getContext());
751}
752
753AffineMap mlir::inversePermutation(AffineMap map) {
754 if (map.isEmpty())
755 return map;
756 assert(map.getNumSymbols() == 0 && "expected map without symbols");
757 SmallVector<AffineExpr, 4> exprs(map.getNumDims());
758 for (const auto &en : llvm::enumerate(First: map.getResults())) {
759 auto expr = en.value();
760 // Skip non-permutations.
761 if (auto d = dyn_cast<AffineDimExpr>(Val&: expr)) {
762 if (exprs[d.getPosition()])
763 continue;
764 exprs[d.getPosition()] = getAffineDimExpr(position: en.index(), context: d.getContext());
765 }
766 }
767 SmallVector<AffineExpr, 4> seenExprs;
768 seenExprs.reserve(N: map.getNumDims());
769 for (auto expr : exprs)
770 if (expr)
771 seenExprs.push_back(Elt: expr);
772 if (seenExprs.size() != map.getNumInputs())
773 return AffineMap();
774 return AffineMap::get(dimCount: map.getNumResults(), symbolCount: 0, results: seenExprs, context: map.getContext());
775}
776
777AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) {
778 assert(map.isProjectedPermutation(/*allowZeroInResults=*/true));
779 MLIRContext *context = map.getContext();
780 AffineExpr zero = mlir::getAffineConstantExpr(constant: 0, context);
781 // Start with all the results as 0.
782 SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero);
783 for (unsigned i : llvm::seq(Begin: unsigned(0), End: map.getNumResults())) {
784 // Skip zeros from input map. 'exprs' is already initialized to zero.
785 if (auto constExpr = dyn_cast<AffineConstantExpr>(Val: map.getResult(idx: i))) {
786 assert(constExpr.getValue() == 0 &&
787 "Unexpected constant in projected permutation");
788 (void)constExpr;
789 continue;
790 }
791
792 // Reverse each dimension existing in the original map result.
793 exprs[map.getDimPosition(idx: i)] = getAffineDimExpr(position: i, context);
794 }
795 return AffineMap::get(dimCount: map.getNumResults(), /*symbolCount=*/0, results: exprs, context);
796}
797
798AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
799 unsigned numResults = 0, numDims = 0, numSymbols = 0;
800 for (auto m : maps)
801 numResults += m.getNumResults();
802 SmallVector<AffineExpr, 8> results;
803 results.reserve(N: numResults);
804 for (auto m : maps) {
805 for (auto res : m.getResults())
806 results.push_back(Elt: res.shiftSymbols(numSymbols: m.getNumSymbols(), shift: numSymbols));
807
808 numSymbols += m.getNumSymbols();
809 numDims = std::max(a: m.getNumDims(), b: numDims);
810 }
811 return AffineMap::get(dimCount: numDims, symbolCount: numSymbols, results,
812 context: maps.front().getContext());
813}
814
815/// Common implementation to project out dimensions or symbols from an affine
816/// map based on the template type.
817/// Additionally, if 'compress' is true, the projected out dimensions or symbols
818/// are also dropped from the resulting map.
819template <typename AffineDimOrSymExpr>
820static AffineMap projectCommonImpl(AffineMap map,
821 const llvm::SmallBitVector &toProject,
822 bool compress) {
823 static_assert(llvm::is_one_of<AffineDimOrSymExpr, AffineDimExpr,
824 AffineSymbolExpr>::value,
825 "expected AffineDimExpr or AffineSymbolExpr");
826
827 constexpr bool isDim = std::is_same<AffineDimOrSymExpr, AffineDimExpr>::value;
828 int64_t numDimOrSym = (isDim) ? map.getNumDims() : map.getNumSymbols();
829 SmallVector<AffineExpr> replacements;
830 replacements.reserve(N: numDimOrSym);
831
832 auto createNewDimOrSym = (isDim) ? getAffineDimExpr : getAffineSymbolExpr;
833
834 using replace_fn_ty =
835 std::function<AffineExpr(AffineExpr, ArrayRef<AffineExpr>)>;
836 replace_fn_ty replaceDims = [](AffineExpr e,
837 ArrayRef<AffineExpr> replacements) {
838 return e.replaceDims(dimReplacements: replacements);
839 };
840 replace_fn_ty replaceSymbols = [](AffineExpr e,
841 ArrayRef<AffineExpr> replacements) {
842 return e.replaceSymbols(symReplacements: replacements);
843 };
844 replace_fn_ty replaceNewDimOrSym = (isDim) ? replaceDims : replaceSymbols;
845
846 MLIRContext *context = map.getContext();
847 int64_t newNumDimOrSym = 0;
848 for (unsigned dimOrSym = 0; dimOrSym < numDimOrSym; ++dimOrSym) {
849 if (toProject.test(Idx: dimOrSym)) {
850 replacements.push_back(Elt: getAffineConstantExpr(constant: 0, context));
851 continue;
852 }
853 int64_t newPos = compress ? newNumDimOrSym++ : dimOrSym;
854 replacements.push_back(Elt: createNewDimOrSym(newPos, context));
855 }
856 SmallVector<AffineExpr> resultExprs;
857 resultExprs.reserve(N: map.getNumResults());
858 for (auto e : map.getResults())
859 resultExprs.push_back(Elt: replaceNewDimOrSym(e, replacements));
860
861 int64_t numDims = (compress && isDim) ? newNumDimOrSym : map.getNumDims();
862 int64_t numSyms = (compress && !isDim) ? newNumDimOrSym : map.getNumSymbols();
863 return AffineMap::get(dimCount: numDims, symbolCount: numSyms, results: resultExprs, context);
864}
865
866AffineMap mlir::projectDims(AffineMap map,
867 const llvm::SmallBitVector &projectedDimensions,
868 bool compressDimsFlag) {
869 return projectCommonImpl<AffineDimExpr>(map, toProject: projectedDimensions,
870 compress: compressDimsFlag);
871}
872
873AffineMap mlir::projectSymbols(AffineMap map,
874 const llvm::SmallBitVector &projectedSymbols,
875 bool compressSymbolsFlag) {
876 return projectCommonImpl<AffineSymbolExpr>(map, toProject: projectedSymbols,
877 compress: compressSymbolsFlag);
878}
879
880AffineMap mlir::getProjectedMap(AffineMap map,
881 const llvm::SmallBitVector &projectedDimensions,
882 bool compressDimsFlag,
883 bool compressSymbolsFlag) {
884 map = projectDims(map, projectedDimensions, compressDimsFlag);
885 if (compressSymbolsFlag)
886 map = compressUnusedSymbols(map);
887 return map;
888}
889
890llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
891 unsigned numDims = maps[0].getNumDims();
892 llvm::SmallBitVector numDimsBitVector(numDims, true);
893 for (AffineMap m : maps) {
894 for (unsigned i = 0; i < numDims; ++i) {
895 if (m.isFunctionOfDim(position: i))
896 numDimsBitVector.reset(Idx: i);
897 }
898 }
899 return numDimsBitVector;
900}
901
902llvm::SmallBitVector mlir::getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps) {
903 unsigned numSymbols = maps[0].getNumSymbols();
904 llvm::SmallBitVector numSymbolsBitVector(numSymbols, true);
905 for (AffineMap m : maps) {
906 for (unsigned i = 0; i < numSymbols; ++i) {
907 if (m.isFunctionOfSymbol(position: i))
908 numSymbolsBitVector.reset(Idx: i);
909 }
910 }
911 return numSymbolsBitVector;
912}
913
914AffineMap
915mlir::expandDimsToRank(AffineMap map, int64_t rank,
916 const llvm::SmallBitVector &projectedDimensions) {
917 auto id = AffineMap::getMultiDimIdentityMap(numDims: rank, context: map.getContext());
918 AffineMap proj = id.dropResults(positions: projectedDimensions);
919 return map.compose(map: proj);
920}
921
922//===----------------------------------------------------------------------===//
923// MutableAffineMap.
924//===----------------------------------------------------------------------===//
925
926MutableAffineMap::MutableAffineMap(AffineMap map)
927 : results(map.getResults().begin(), map.getResults().end()),
928 numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
929 context(map.getContext()) {}
930
931void MutableAffineMap::reset(AffineMap map) {
932 results.clear();
933 numDims = map.getNumDims();
934 numSymbols = map.getNumSymbols();
935 context = map.getContext();
936 llvm::append_range(C&: results, R: map.getResults());
937}
938
939bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
940 return results[idx].isMultipleOf(factor);
941}
942
943// Simplifies the result affine expressions of this map. The expressions
944// have to be pure for the simplification implemented.
945void MutableAffineMap::simplify() {
946 // Simplify each of the results if possible.
947 // TODO: functional-style map
948 for (unsigned i = 0, e = getNumResults(); i < e; i++) {
949 results[i] = simplifyAffineExpr(expr: getResult(idx: i), numDims, numSymbols);
950 }
951}
952
953AffineMap MutableAffineMap::getAffineMap() const {
954 return AffineMap::get(dimCount: numDims, symbolCount: numSymbols, results, context);
955}
956

source code of mlir/lib/IR/AffineMap.cpp