//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTLINALGTOAFFINELOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOLOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOPARALLELLOOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

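/// Emits one `affine.apply` per result expression of `map`, applied to `vals`,
/// and returns the resulting values. Each single-result map is canonicalized
/// together with its operands before the apply is created. For example
/// (illustrative only), a map `(d0, d1) -> (d0 + d1, d0)` applied to `%i, %j`
/// yields two values:
///   %0 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
///   %1 = affine.apply affine_map<(d0) -> (d0)>(%i)
/// Trivial applies such as `%1` are folded away later (see FoldAffineOp).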
static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
                                                     AffineMap map,
                                                     ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value> operands(vals.begin(), vals.end());
    affine::canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<affine::AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}

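/// Inlines the single-block region of `op` at the current insertion point,
/// remapping the block arguments to the already-loaded `indexedValues`, and
/// then emits one store per terminator operand: operand `i` is stored into
/// `outputBuffers[i]` at the indices `indexing[i]`. As a rough sketch
/// (illustrative only), for a body `%0 = arith.addf %a, %b; linalg.yield %0`
/// this clones the `arith.addf` on the loaded scalars and emits a single
/// store of its result into the output buffer.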
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
                                     ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &block = op->getRegion(0).front();
  IRMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation *terminator = block.getTerminator();
  for (OpOperand &operand : terminator->getOpOperands()) {
    Value toStore = map.lookupOrDefault(operand.get());
    b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
                        indexing[operand.getOperandNumber()]);
  }
}

// Returns a pair that contains input indices and output indices of a
// SingleInputPoolingOp `op`.
struct InputAndOutputIndices {
  SmallVector<Value> inputs;
  SmallVector<Value> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices
getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
                         SingleInputPoolingOp op) {
  auto mapsRange = op.getIndexingMapsArray();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");
  SmallVector<Value> indexedValues;
  indexedValues.reserve(linalgOp->getNumOperands());

  auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input operand or for scalars access the operand itself.
  for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
    if (linalgOp.isScalar(inputOperand)) {
      indexedValues.push_back(inputOperand->get());
      continue;
    }
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
  }
  // 1.b. Emit load from output views.
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    SmallVector<Value> indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, outputOperand.get(), indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value>, 8> indexing;
  SmallVector<Value> outputBuffers;
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    if (!isa<MemRefType>(outputOperand.get().getType()))
      continue;
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims));
    outputBuffers.push_back(outputOperand.get());
  }
  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
                                                indexing, outputBuffers);
}

/// Replace the index operations in the body of the loop nest by the matching
/// induction variables.
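/// For example (illustrative only), after lowering a `linalg.generic` whose
/// body contains `%d1 = linalg.index 1 : index`, every use of `%d1` is
/// replaced by the induction variable of the loop that implements dimension 1
/// (e.g. `%j` in an `scf.for %j = ...` nest).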
static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
                                                LinalgOp linalgOp,
                                                ArrayRef<Operation *> loopOps) {
  // Extract the induction variables of the loop nest from outer to inner.
  SmallVector<Value> allIvs;
  for (Operation *loopOp : loopOps) {
    llvm::TypeSwitch<Operation *>(loopOp)
        .Case([&](scf::ParallelOp parallelOp) {
          allIvs.append(parallelOp.getInductionVars().begin(),
                        parallelOp.getInductionVars().end());
        })
        .Case([&](scf::ForOp forOp) {
          allIvs.push_back(forOp.getInductionVar());
        })
        .Case([&](affine::AffineForOp affineForOp) {
          allIvs.push_back(affineForOp.getInductionVar());
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });
  }
  assert(linalgOp.getNumLoops() == allIvs.size() &&
         "expected the number of loops and induction variables to match");
  // Replace the index operations in the body of the innermost loop op.
  if (!loopOps.empty()) {
    auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
    for (Region *r : loopOp.getLoopRegions())
      for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
        rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
  }
}

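/// Lowers `linalgOp`, which must have buffer semantics, to a nest of `LoopTy`
/// loops with an explicit scalar body. `affine.for` nests use
/// `affine.load`/`affine.store` in the body, while `scf.for` and
/// `scf.parallel` nests use `memref.load`/`memref.store`. Returns the loop
/// operations from outermost to innermost, or failure if an induction
/// variable could not be recovered as a loop block argument.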
template <typename LoopTy>
static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  using LoadOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineLoadOp, memref::LoadOp>;
  using StoreOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineStoreOp, memref::StoreOp>;

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  SmallVector<Value> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
        assert(operandValuesToUse == linalgOp->getOperands() &&
               "expect operands are captured and not passed by loop argument");
        allIvs.append(ivs.begin(), ivs.end());
        emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
        return scf::ValueVector{};
      });
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return failure();
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = dyn_cast<BlockArgument>(iv);
    if (!ivVal)
      return failure();
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  // Replace all index operations in the loop body.
  replaceIndexOpsByInductionVariables(rewriter, linalgOp, loops);
  return loops;
}

namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!isa<LinalgOp>(op) || !linalgOp.hasPureBufferSemantics()) {
      return rewriter.notifyMatchFailure(
          op, "expected linalg op with buffer semantics");
    }
    if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operand + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
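///
/// For example (illustrative only):
///   %0 = affine.apply affine_map<() -> (42)>()          // -> arith.constant 42 : index
///   %1 = affine.apply affine_map<(d0) -> (d0)>(%iv)     // -> replaced by %iv
///   %2 = affine.apply affine_map<()[s0] -> (s0)>()[%s]  // -> replaced by %s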
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(affine::AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto affineApplyOp = cast<affine::AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = dyn_cast<AffineConstantExpr>(expr)) {
        rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (dyn_cast<AffineDimExpr>(expr) || dyn_cast<AffineSymbolExpr>(expr)) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

template <typename LoopType>
static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
  MLIRContext *context = enclosingOp->getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsAndFoldGreedily(enclosingOp, std::move(patterns));
}

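/// The pass wrappers below run the lowering over an enclosing operation; they
/// are normally exercised through the autogenerated pass options (e.g. the
/// `convert-linalg-to-affine-loops`, `convert-linalg-to-loops` and
/// `convert-linalg-to-parallel-loops` mlir-opt/pipeline flags).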
struct LowerToAffineLoops
    : public impl::ConvertLinalgToAffineLoopsPassBase<LowerToAffineLoops> {
  using impl::ConvertLinalgToAffineLoopsPassBase<
      LowerToAffineLoops>::ConvertLinalgToAffineLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<affine::AffineForOp>(getOperation());
  }
};

struct LowerToLoops : public impl::ConvertLinalgToLoopsPassBase<LowerToLoops> {
  using impl::ConvertLinalgToLoopsPassBase<
      LowerToLoops>::ConvertLinalgToLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
  }
};

struct LowerToParallelLoops
    : public impl::ConvertLinalgToParallelLoopsPassBase<LowerToParallelLoops> {
  using impl::ConvertLinalgToParallelLoopsPassBase<
      LowerToParallelLoops>::ConvertLinalgToParallelLoopsPassBase;
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
  }
};

} // namespace

/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<affine::AffineForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(RewriterBase &rewriter,
                                                     LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(RewriterBase &rewriter,
                                      LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}
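
// Example (illustrative sketch, mirroring LinalgRewritePattern above): the
// entry points are typically called from a rewrite pattern that drives the
// lowering, e.g.:
//
//   FailureOr<LinalgLoops> loops = linalgOpToLoops(rewriter, linalgOp);
//   if (failed(loops))
//     return failure();
//   rewriter.eraseOp(linalgOp);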