//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTLINALGTOAFFINELOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOLOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOPARALLELLOOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

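/// Applies each result expression of `map` to `vals`, materializing it as a
/// canonicalized `affine.apply`, and returns one Value per map result.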
static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
                                                     AffineMap map,
                                                     ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value> operands(vals);
    affine::canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<affine::AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}

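/// Clones the single-block region of `op` at the current insertion point,
/// substituting `indexedValues` for the block arguments, and emits a
/// `StoreOpTy` into the corresponding `outputBuffers` at `indexing` for each
/// operand of the region terminator.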
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
                                     ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &block = op->getRegion(0).front();
  IRMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation *terminator = block.getTerminator();
  for (OpOperand &operand : terminator->getOpOperands()) {
    Value toStore = map.lookupOrDefault(operand.get());
    b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
                        indexing[operand.getOperandNumber()]);
  }
}

// Returns a pair that contains input indices and output indices of a
// SingleInputPoolingOp `op`.
struct InputAndOutputIndices {
  SmallVector<Value> inputs;
  SmallVector<Value> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices
getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
                         SingleInputPoolingOp op) {
  auto mapsRange = op.getIndexingMapsArray();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");
  SmallVector<Value> indexedValues;
  indexedValues.reserve(linalgOp->getNumOperands());

  auto allIvsPlusDims = SmallVector<Value>(allIvs);

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input operand or for scalars access the operand
  // itself.
  for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
    if (linalgOp.isScalar(inputOperand)) {
      indexedValues.push_back(inputOperand->get());
      continue;
    }
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
  }
  // 1.b. Emit load from output views.
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    SmallVector<Value> indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, outputOperand.get(), indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value>, 8> indexing;
  SmallVector<Value> outputBuffers;
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    if (!isa<MemRefType>(outputOperand.get().getType()))
      continue;
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims));
    outputBuffers.push_back(outputOperand.get());
  }
  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
                                                indexing, outputBuffers);
}

/// Replace the index operations in the body of the loop nest by the matching
/// induction variables.
static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
                                                LinalgOp linalgOp,
                                                ArrayRef<Operation *> loopOps) {
  // Extract the induction variables of the loop nest from outer to inner.
  SmallVector<Value> allIvs;
  for (Operation *loopOp : loopOps) {
    llvm::TypeSwitch<Operation *>(loopOp)
        .Case([&](scf::ParallelOp parallelOp) {
          allIvs.append(parallelOp.getInductionVars());
        })
        .Case([&](scf::ForOp forOp) {
          allIvs.push_back(forOp.getInductionVar());
        })
        .Case([&](affine::AffineForOp affineForOp) {
          allIvs.push_back(affineForOp.getInductionVar());
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });
  }
  assert(linalgOp.getNumLoops() == allIvs.size() &&
         "expected the number of loops and induction variables to match");
  // Replace the index operations in the body of the innermost loop op.
  if (!loopOps.empty()) {
    auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
    for (Region *r : loopOp.getLoopRegions())
      for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
        rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
  }
}

template <typename LoopTy>
static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  using LoadOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineLoadOp, memref::LoadOp>;
  using StoreOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineStoreOp, memref::StoreOp>;

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  SmallVector<Value> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
        assert(operandValuesToUse == linalgOp->getOperands() &&
               "expect operands are captured and not passed by loop argument");
        allIvs.append(ivs.begin(), ivs.end());
        emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
        return scf::ValueVector{};
      });
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return failure();
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = dyn_cast<BlockArgument>(iv);
    if (!ivVal)
      return failure();
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  // Replace all index operations in the loop body.
  replaceIndexOpsByInductionVariables(rewriter, linalgOp, loops);
  return loops;
}

namespace {
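/// Generic rewrite pattern that lowers any LinalgOp with buffer semantics to
/// an explicit loop nest of type `LoopType` and erases the original op.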
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp || !linalgOp.hasPureBufferSemantics()) {
      return rewriter.notifyMatchFailure(
          op, "expected linalg op with buffer semantics");
    }
    if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operand + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(affine::AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto affineApplyOp = cast<affine::AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = dyn_cast<AffineConstantExpr>(expr)) {
        rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (isa<AffineDimExpr, AffineSymbolExpr>(expr)) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

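/// Populates a pattern set with the Linalg-to-loops pattern for `LoopType`
/// plus supporting dim-op canonicalizations and affine.apply folding, then
/// applies the patterns greedily to `enclosingOp`.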
template <typename LoopType>
static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
  MLIRContext *context = enclosingOp->getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsGreedily(enclosingOp, std::move(patterns));
}

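/// Pass that lowers Linalg ops with buffer semantics within the operation it
/// runs on to `affine.for` loop nests.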
struct LowerToAffineLoops
    : public impl::ConvertLinalgToAffineLoopsPassBase<LowerToAffineLoops> {
  using impl::ConvertLinalgToAffineLoopsPassBase<
      LowerToAffineLoops>::ConvertLinalgToAffineLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<affine::AffineForOp>(getOperation());
  }
};

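/// Pass that lowers Linalg ops with buffer semantics to `scf.for` loop nests.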
struct LowerToLoops : public impl::ConvertLinalgToLoopsPassBase<LowerToLoops> {
  using impl::ConvertLinalgToLoopsPassBase<
      LowerToLoops>::ConvertLinalgToLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
  }
};

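/// Pass that lowers Linalg ops with buffer semantics to `scf.parallel` loop
/// nests where the iterator types permit it.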
struct LowerToParallelLoops
    : public impl::ConvertLinalgToParallelLoopsPassBase<LowerToParallelLoops> {
  using impl::ConvertLinalgToParallelLoopsPassBase<
      LowerToParallelLoops>::ConvertLinalgToParallelLoopsPassBase;
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
  }
};

} // namespace

/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<affine::AffineForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(RewriterBase &rewriter,
                                                     LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(RewriterBase &rewriter,
                                      LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}