1//===- DecomposeAffineOps.cpp - Decompose affine ops into finer-grained ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements functionality to progressively decompose coarse-grained
10// affine ops into finer-grained ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Affine/Transforms/Transforms.h"
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18#include "llvm/Support/Debug.h"
19
20using namespace mlir;
21using namespace mlir::affine;
22
23#define DEBUG_TYPE "decompose-affine-ops"
24#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
25#define DBGSNL() (llvm::dbgs() << "\n")
26
27/// Count the number of loops surrounding `operand` such that operand could be
28/// hoisted above.
29/// Stop counting at the first loop over which the operand cannot be hoisted.
30static int64_t numEnclosingInvariantLoops(OpOperand &operand) {
31 int64_t count = 0;
32 Operation *currentOp = operand.getOwner();
33 while (auto loopOp = currentOp->getParentOfType<LoopLikeOpInterface>()) {
34 if (!loopOp.isDefinedOutsideOfLoop(operand.get()))
35 break;
36 currentOp = loopOp;
37 count++;
38 }
39 return count;
40}
41
42void mlir::affine::reorderOperandsByHoistability(RewriterBase &rewriter,
43 AffineApplyOp op) {
44 SmallVector<int64_t> numInvariant = llvm::to_vector(
45 llvm::map_range(op->getOpOperands(), [&](OpOperand &operand) {
46 return numEnclosingInvariantLoops(operand);
47 }));
48
49 int64_t numOperands = op.getNumOperands();
50 SmallVector<int64_t> operandPositions =
51 llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: numOperands));
52 llvm::stable_sort(Range&: operandPositions, C: [&numInvariant](size_t i1, size_t i2) {
53 return numInvariant[i1] > numInvariant[i2];
54 });
55
56 SmallVector<AffineExpr> replacements(numOperands);
57 SmallVector<Value> operands(numOperands);
58 for (int64_t i = 0; i < numOperands; ++i) {
59 operands[i] = op.getOperand(operandPositions[i]);
60 replacements[operandPositions[i]] = getAffineSymbolExpr(i, op.getContext());
61 }
62
63 AffineMap map = op.getAffineMap();
64 ArrayRef<AffineExpr> repls{replacements};
65 map = map.replaceDimsAndSymbols(dimReplacements: repls.take_front(N: map.getNumDims()),
66 symReplacements: repls.drop_front(N: map.getNumDims()),
67 /*numResultDims=*/0,
68 /*numResultSyms=*/numOperands);
69 map = AffineMap::get(0, numOperands,
70 simplifyAffineExpr(expr: map.getResult(idx: 0), numDims: 0, numSymbols: numOperands),
71 op->getContext());
72 canonicalizeMapAndOperands(map: &map, operands: &operands);
73
74 rewriter.startOpModification(op: op);
75 op.setMap(map);
76 op->setOperands(operands);
77 rewriter.finalizeOpModification(op: op);
78}
79
80/// Build an affine.apply that is a subexpression `expr` of `originalOp`s affine
81/// map and with the same operands.
82/// Canonicalize the map and operands to deduplicate and drop dead operands
83/// before returning but do not perform maximal composition of AffineApplyOp
84/// which would defeat the purpose.
85static AffineApplyOp createSubApply(RewriterBase &rewriter,
86 AffineApplyOp originalOp, AffineExpr expr) {
87 MLIRContext *ctx = originalOp->getContext();
88 AffineMap m = originalOp.getAffineMap();
89 auto rhsMap = AffineMap::get(dimCount: m.getNumDims(), symbolCount: m.getNumSymbols(), results: expr, context: ctx);
90 SmallVector<Value> rhsOperands = originalOp->getOperands();
91 canonicalizeMapAndOperands(&rhsMap, &rhsOperands);
92 return rewriter.create<AffineApplyOp>(originalOp.getLoc(), rhsMap,
93 rhsOperands);
94}
95
96FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
97 AffineApplyOp op) {
98 // 1. Preconditions: only handle dimensionless AffineApplyOp maps with a
99 // top-level binary expression that we can reassociate (i.e. add or mul).
100 AffineMap m = op.getAffineMap();
101 if (m.getNumDims() > 0)
102 return rewriter.notifyMatchFailure(op, "expected no dims");
103
104 AffineExpr remainingExp = m.getResult(idx: 0);
105 auto binExpr = dyn_cast<AffineBinaryOpExpr>(Val&: remainingExp);
106 if (!binExpr)
107 return rewriter.notifyMatchFailure(op, "terminal affine.apply");
108
109 if (!isa<AffineBinaryOpExpr>(binExpr.getLHS()) &&
110 !isa<AffineBinaryOpExpr>(binExpr.getRHS()))
111 return rewriter.notifyMatchFailure(op, "terminal affine.apply");
112
113 bool supportedKind = ((binExpr.getKind() == AffineExprKind::Add) ||
114 (binExpr.getKind() == AffineExprKind::Mul));
115 if (!supportedKind)
116 return rewriter.notifyMatchFailure(
117 op, "only add or mul binary expr can be reassociated");
118
119 LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n");
120
121 // 2. Iteratively extract the RHS subexpressions while the top-level binary
122 // expr kind remains the same.
123 MLIRContext *ctx = op->getContext();
124 SmallVector<AffineExpr> subExpressions;
125 while (true) {
126 auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(Val&: remainingExp);
127 if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) {
128 subExpressions.push_back(Elt: remainingExp);
129 LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n");
130 break;
131 }
132 subExpressions.push_back(Elt: currentBinExpr.getRHS());
133 LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n");
134 remainingExp = currentBinExpr.getLHS();
135 }
136
137 // 3. Reorder subExpressions by the min symbol they are a function of.
138 // This also takes care of properly reordering local variables.
139 // This however won't be able to split expression that cannot be reassociated
140 // such as ones that involve divs and multiple symbols.
141 auto getMaxSymbol = [&](AffineExpr e) -> int64_t {
142 for (int64_t i = m.getNumSymbols(); i >= 0; --i)
143 if (e.isFunctionOfSymbol(position: i))
144 return i;
145 return -1;
146 };
147 llvm::stable_sort(Range&: subExpressions, C: [&](AffineExpr e1, AffineExpr e2) {
148 return getMaxSymbol(e1) < getMaxSymbol(e2);
149 });
150 LLVM_DEBUG(
151 llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: ");
152 llvm::dbgs() << "\n");
153
154 // 4. Merge sorted subExpressions iteratively, thus achieving reassociation.
155 auto s0 = getAffineSymbolExpr(position: 0, context: ctx);
156 auto s1 = getAffineSymbolExpr(position: 1, context: ctx);
157 AffineMap binMap = AffineMap::get(
158 /*dimCount=*/0, /*symbolCount=*/2,
159 getAffineBinaryOpExpr(binExpr.getKind(), s0, s1), ctx);
160
161 auto current = createSubApply(rewriter, op, subExpressions[0]);
162 for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) {
163 Value tmp = createSubApply(rewriter, op, subExpressions[i]);
164 current = rewriter.create<AffineApplyOp>(op.getLoc(), binMap,
165 ValueRange{current, tmp});
166 LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n");
167 }
168
169 // 5. Replace original op.
170 rewriter.replaceOp(op, current.getResult());
171 return current;
172}
173

source code of mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp