//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTLINALGTOAFFINELOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOLOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOPARALLELLOOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

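/// Emits one `affine.apply` per result expression of `map`, applied to `vals`,
/// and returns the resulting index values. Each single-result map is
/// canonicalized together with its operands before the apply is created. For
/// example, a map `(d0, d1) -> (d0 + d1, d0)` with values `[%i, %j]` yields
/// two values computing `%i + %j` and `%i`.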
static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
                                                     AffineMap map,
                                                     ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value> operands(vals.begin(), vals.end());
    affine::canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<affine::AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}

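/// Clones the body of `op`'s single-block region at the current insertion
/// point, remapping the block arguments to `indexedValues`, and then emits one
/// store per terminator operand into the corresponding `outputBuffers` entry,
/// using the precomputed `indexing` expressions as store indices.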
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
                                     ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &block = op->getRegion(0).front();
  IRMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation *terminator = block.getTerminator();
  for (OpOperand &operand : terminator->getOpOperands()) {
    Value toStore = map.lookupOrDefault(operand.get());
    b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
                        indexing[operand.getOperandNumber()]);
  }
}

// Returns a pair that contains input indices and output indices of a
// SingleInputPoolingOp `op`.
struct InputAndOutputIndices {
  SmallVector<Value> inputs;
  SmallVector<Value> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices
getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
                         SingleInputPoolingOp op) {
  auto mapsRange = op.getIndexingMapsArray();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");
  SmallVector<Value> indexedValues;
  indexedValues.reserve(linalgOp->getNumOperands());

  auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input operand or for scalars access the operand
  // itself.
  for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
    if (linalgOp.isScalar(inputOperand)) {
      indexedValues.push_back(inputOperand->get());
      continue;
    }
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
  }
  // 1.b. Emit load from output views.
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    SmallVector<Value> indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, outputOperand.get(), indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value>, 8> indexing;
  SmallVector<Value> outputBuffers;
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    if (!isa<MemRefType>(outputOperand.get().getType()))
      continue;
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims));
    outputBuffers.push_back(outputOperand.get());
  }
  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
                                                indexing, outputBuffers);
}

/// Replace the index operations in the body of the loop nest by the matching
/// induction variables.
static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
                                                LinalgOp linalgOp,
                                                ArrayRef<Operation *> loopOps) {
  // Extract the induction variables of the loop nest from outer to inner.
  SmallVector<Value> allIvs;
  for (Operation *loopOp : loopOps) {
    llvm::TypeSwitch<Operation *>(loopOp)
        .Case([&](scf::ParallelOp parallelOp) {
          allIvs.append(parallelOp.getInductionVars().begin(),
                        parallelOp.getInductionVars().end());
        })
        .Case([&](scf::ForOp forOp) {
          allIvs.push_back(forOp.getInductionVar());
        })
        .Case([&](affine::AffineForOp affineForOp) {
          allIvs.push_back(affineForOp.getInductionVar());
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });
  }
  assert(linalgOp.getNumLoops() == allIvs.size() &&
         "expected the number of loops and induction variables to match");
  // Replace the index operations in the body of the innermost loop op.
  if (!loopOps.empty()) {
    auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
    for (Region *r : loopOp.getLoopRegions())
      for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
        rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
  }
}

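/// Core lowering routine shared by the entry points at the end of this file:
/// generates a loop nest of `LoopTy` over the iteration domain of `linalgOp`,
/// emits the scalar body via `emitScalarImplementation`, collects the created
/// loop operations from the induction variables, and finally rewrites
/// `linalg.index` ops in terms of those induction variables.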
template <typename LoopTy>
static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  using LoadOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineLoadOp, memref::LoadOp>;
  using StoreOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineStoreOp, memref::StoreOp>;

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  SmallVector<Value> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
        assert(operandValuesToUse == linalgOp->getOperands() &&
               "expect operands are captured and not passed by loop argument");
        allIvs.append(ivs.begin(), ivs.end());
        emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
        return scf::ValueVector{};
      });
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return failure();
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = dyn_cast<BlockArgument>(iv);
    if (!ivVal)
      return failure();
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  // Replace all index operations in the loop body.
  replaceIndexOpsByInductionVariables(rewriter, linalgOp, loops);
  return loops;
}

namespace {
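/// Rewrite pattern that matches any LinalgOp with buffer semantics, lowers it
/// to a loop nest of `LoopType`, and erases the original op on success.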
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!isa<LinalgOp>(op) || !linalgOp.hasPureBufferSemantics()) {
      return rewriter.notifyMatchFailure(
          op, "expected linalg op with buffer semantics");
    }
    if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operand + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
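///
/// For illustration, trivial maps of each kind look like:
///   affine_map<() -> (42)>        (case 1, folded to a constant index)
///   affine_map<(d0) -> (d0)>      (case 2, replaced by the single operand)
///   affine_map<()[s0] -> (s0)>    (case 3, replaced by the single operand)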
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(affine::AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto affineApplyOp = cast<affine::AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = dyn_cast<AffineConstantExpr>(expr)) {
        rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (dyn_cast<AffineDimExpr>(expr) || dyn_cast<AffineSymbolExpr>(expr)) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

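/// Applies the loop-lowering rewrite pattern for `LoopType` together with a
/// small set of cleanup patterns (dim-op canonicalizations and affine.apply
/// folding) greedily on `enclosingOp`.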
template <typename LoopType>
static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
  MLIRContext *context = enclosingOp->getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsAndFoldGreedily(enclosingOp, std::move(patterns));
}

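/// Pass that lowers Linalg ops with buffer semantics to `affine.for` loop
/// nests.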
struct LowerToAffineLoops
    : public impl::ConvertLinalgToAffineLoopsPassBase<LowerToAffineLoops> {
  using impl::ConvertLinalgToAffineLoopsPassBase<
      LowerToAffineLoops>::ConvertLinalgToAffineLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<affine::AffineForOp>(getOperation());
  }
};

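/// Pass that lowers Linalg ops with buffer semantics to `scf.for` loop nests.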
struct LowerToLoops : public impl::ConvertLinalgToLoopsPassBase<LowerToLoops> {
  using impl::ConvertLinalgToLoopsPassBase<
      LowerToLoops>::ConvertLinalgToLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
  }
};

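/// Pass that lowers Linalg ops with buffer semantics to `scf.parallel` loop
/// nests.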
struct LowerToParallelLoops
    : public impl::ConvertLinalgToParallelLoopsPassBase<LowerToParallelLoops> {
  using impl::ConvertLinalgToParallelLoopsPassBase<
      LowerToParallelLoops>::ConvertLinalgToParallelLoopsPassBase;
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
  }
};

} // namespace

/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<affine::AffineForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(RewriterBase &rewriter,
                                                     LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(RewriterBase &rewriter,
                                      LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}