1//===- DecomposeAffineOps.cpp - Decompose affine ops into finer-grained ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements functionality to progressively decompose coarse-grained
10// affine ops into finer-grained ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Affine/Transforms/Transforms.h"
16#include "mlir/IR/PatternMatch.h"
17#include "llvm/Support/Debug.h"
18
19using namespace mlir;
20using namespace mlir::affine;
21
22#define DEBUG_TYPE "decompose-affine-ops"
23#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
24#define DBGSNL() (llvm::dbgs() << "\n")
25
26/// Count the number of loops surrounding `operand` such that operand could be
27/// hoisted above.
28/// Stop counting at the first loop over which the operand cannot be hoisted.
29static int64_t numEnclosingInvariantLoops(OpOperand &operand) {
30 int64_t count = 0;
31 Operation *currentOp = operand.getOwner();
32 while (auto loopOp = currentOp->getParentOfType<LoopLikeOpInterface>()) {
33 if (!loopOp.isDefinedOutsideOfLoop(value: operand.get()))
34 break;
35 currentOp = loopOp;
36 count++;
37 }
38 return count;
39}
40
41void mlir::affine::reorderOperandsByHoistability(RewriterBase &rewriter,
42 AffineApplyOp op) {
43 SmallVector<int64_t> numInvariant = llvm::to_vector(
44 Range: llvm::map_range(C: op->getOpOperands(), F: [&](OpOperand &operand) {
45 return numEnclosingInvariantLoops(operand);
46 }));
47
48 int64_t numOperands = op.getNumOperands();
49 SmallVector<int64_t> operandPositions =
50 llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: numOperands));
51 llvm::stable_sort(Range&: operandPositions, C: [&numInvariant](size_t i1, size_t i2) {
52 return numInvariant[i1] > numInvariant[i2];
53 });
54
55 SmallVector<AffineExpr> replacements(numOperands);
56 SmallVector<Value> operands(numOperands);
57 for (int64_t i = 0; i < numOperands; ++i) {
58 operands[i] = op.getOperand(i: operandPositions[i]);
59 replacements[operandPositions[i]] = getAffineSymbolExpr(position: i, context: op.getContext());
60 }
61
62 AffineMap map = op.getAffineMap();
63 ArrayRef<AffineExpr> repls{replacements};
64 map = map.replaceDimsAndSymbols(dimReplacements: repls.take_front(N: map.getNumDims()),
65 symReplacements: repls.drop_front(N: map.getNumDims()),
66 /*numResultDims=*/0,
67 /*numResultSyms=*/numOperands);
68 map = AffineMap::get(dimCount: 0, symbolCount: numOperands,
69 results: simplifyAffineExpr(expr: map.getResult(idx: 0), numDims: 0, numSymbols: numOperands),
70 context: op->getContext());
71 canonicalizeMapAndOperands(map: &map, operands: &operands);
72
73 rewriter.startOpModification(op);
74 op.setMap(map);
75 op->setOperands(operands);
76 rewriter.finalizeOpModification(op);
77}
78
79/// Build an affine.apply that is a subexpression `expr` of `originalOp`s affine
80/// map and with the same operands.
81/// Canonicalize the map and operands to deduplicate and drop dead operands
82/// before returning but do not perform maximal composition of AffineApplyOp
83/// which would defeat the purpose.
84static AffineApplyOp createSubApply(RewriterBase &rewriter,
85 AffineApplyOp originalOp, AffineExpr expr) {
86 MLIRContext *ctx = originalOp->getContext();
87 AffineMap m = originalOp.getAffineMap();
88 auto rhsMap = AffineMap::get(dimCount: m.getNumDims(), symbolCount: m.getNumSymbols(), results: expr, context: ctx);
89 SmallVector<Value> rhsOperands = originalOp->getOperands();
90 canonicalizeMapAndOperands(map: &rhsMap, operands: &rhsOperands);
91 return rewriter.create<AffineApplyOp>(location: originalOp.getLoc(), args&: rhsMap,
92 args&: rhsOperands);
93}
94
95FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
96 AffineApplyOp op) {
97 // 1. Preconditions: only handle dimensionless AffineApplyOp maps with a
98 // top-level binary expression that we can reassociate (i.e. add or mul).
99 AffineMap m = op.getAffineMap();
100 if (m.getNumDims() > 0)
101 return rewriter.notifyMatchFailure(arg&: op, msg: "expected no dims");
102
103 AffineExpr remainingExp = m.getResult(idx: 0);
104 auto binExpr = dyn_cast<AffineBinaryOpExpr>(Val&: remainingExp);
105 if (!binExpr)
106 return rewriter.notifyMatchFailure(arg&: op, msg: "terminal affine.apply");
107
108 if (!isa<AffineBinaryOpExpr>(Val: binExpr.getLHS()) &&
109 !isa<AffineBinaryOpExpr>(Val: binExpr.getRHS()))
110 return rewriter.notifyMatchFailure(arg&: op, msg: "terminal affine.apply");
111
112 bool supportedKind = ((binExpr.getKind() == AffineExprKind::Add) ||
113 (binExpr.getKind() == AffineExprKind::Mul));
114 if (!supportedKind)
115 return rewriter.notifyMatchFailure(
116 arg&: op, msg: "only add or mul binary expr can be reassociated");
117
118 LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n");
119
120 // 2. Iteratively extract the RHS subexpressions while the top-level binary
121 // expr kind remains the same.
122 MLIRContext *ctx = op->getContext();
123 SmallVector<AffineExpr> subExpressions;
124 while (true) {
125 auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(Val&: remainingExp);
126 if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) {
127 subExpressions.push_back(Elt: remainingExp);
128 LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n");
129 break;
130 }
131 subExpressions.push_back(Elt: currentBinExpr.getRHS());
132 LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n");
133 remainingExp = currentBinExpr.getLHS();
134 }
135
136 // 3. Reorder subExpressions by the min symbol they are a function of.
137 // This also takes care of properly reordering local variables.
138 // This however won't be able to split expression that cannot be reassociated
139 // such as ones that involve divs and multiple symbols.
140 auto getMaxSymbol = [&](AffineExpr e) -> int64_t {
141 for (int64_t i = m.getNumSymbols(); i >= 0; --i)
142 if (e.isFunctionOfSymbol(position: i))
143 return i;
144 return -1;
145 };
146 llvm::stable_sort(Range&: subExpressions, C: [&](AffineExpr e1, AffineExpr e2) {
147 return getMaxSymbol(e1) < getMaxSymbol(e2);
148 });
149 LLVM_DEBUG(
150 llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: ");
151 llvm::dbgs() << "\n");
152
153 // 4. Merge sorted subExpressions iteratively, thus achieving reassociation.
154 auto s0 = getAffineSymbolExpr(position: 0, context: ctx);
155 auto s1 = getAffineSymbolExpr(position: 1, context: ctx);
156 AffineMap binMap = AffineMap::get(
157 /*dimCount=*/0, /*symbolCount=*/2,
158 results: getAffineBinaryOpExpr(kind: binExpr.getKind(), lhs: s0, rhs: s1), context: ctx);
159
160 auto current = createSubApply(rewriter, originalOp: op, expr: subExpressions[0]);
161 for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) {
162 Value tmp = createSubApply(rewriter, originalOp: op, expr: subExpressions[i]);
163 current = rewriter.create<AffineApplyOp>(location: op.getLoc(), args&: binMap,
164 args: ValueRange{current, tmp});
165 LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n");
166 }
167
168 // 5. Replace original op.
169 rewriter.replaceOp(op, newValues: current.getResult());
170 return current;
171}
172

source code of mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp