//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTLINALGTOAFFINELOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOLOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOPARALLELLOOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;
static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
                                                     AffineMap map,
                                                     ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value> operands(vals);
    affine::canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<affine::AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}

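/// Clones the body of the single-block region of `op` at the current insertion
/// point, remapping its block arguments to `indexedValues`, then emits one
/// StoreOpTy per terminator operand into the matching buffer in
/// `outputBuffers` at the indices given by `indexing`.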
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
                                     ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &block = op->getRegion(0).front();
  IRMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation *terminator = block.getTerminator();
  for (OpOperand &operand : terminator->getOpOperands()) {
    Value toStore = map.lookupOrDefault(operand.get());
    b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
                        indexing[operand.getOperandNumber()]);
  }
}

/// Holds the input and output access indices of a SingleInputPoolingOp, as
/// computed by `getInputAndOutputIndices` below.
struct InputAndOutputIndices {
  SmallVector<Value> inputs;
  SmallVector<Value> outputs;
};
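/// Computes the canonicalized input and output access indices of a
/// SingleInputPoolingOp `op` from the enclosing induction variables `allIvs`,
/// using the op's first indexing map for the input and its third indexing map
/// for the output.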
template <typename SingleInputPoolingOp>
static InputAndOutputIndices
getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
                         SingleInputPoolingOp op) {
  auto mapsRange = op.getIndexingMapsArray();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");
  SmallVector<Value> indexedValues;
  indexedValues.reserve(linalgOp->getNumOperands());

  auto allIvsPlusDims = SmallVector<Value>(allIvs);

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input operands or, for scalars, use the operand
  // itself.
  for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
    if (linalgOp.isScalar(inputOperand)) {
      indexedValues.push_back(inputOperand->get());
      continue;
    }
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
  }
  // 1.b. Emit load from output views.
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    SmallVector<Value> indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, outputOperand.get(), indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value>, 8> indexing;
  SmallVector<Value> outputBuffers;
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    if (!isa<MemRefType>(outputOperand.get().getType()))
      continue;
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims));
    outputBuffers.push_back(outputOperand.get());
  }
  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
                                                indexing, outputBuffers);
}

/// Replace the index operations in the body of the loop nest by the matching
/// induction variables.
static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
                                                LinalgOp linalgOp,
                                                ArrayRef<Operation *> loopOps) {
  // Extract the induction variables of the loop nest from outer to inner.
  SmallVector<Value> allIvs;
  for (Operation *loopOp : loopOps) {
    llvm::TypeSwitch<Operation *>(loopOp)
        .Case([&](scf::ParallelOp parallelOp) {
          allIvs.append(parallelOp.getInductionVars());
        })
        .Case([&](scf::ForOp forOp) {
          allIvs.push_back(forOp.getInductionVar());
        })
        .Case([&](affine::AffineForOp affineForOp) {
          allIvs.push_back(affineForOp.getInductionVar());
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });
  }
  assert(linalgOp.getNumLoops() == allIvs.size() &&
         "expected the number of loops and induction variables to match");
  // Replace the index operations in the body of the innermost loop op.
  if (!loopOps.empty()) {
    auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
    for (Region *r : loopOp.getLoopRegions())
      for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
        rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
  }
}

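/// Shared implementation behind the affine, scf.for, and scf.parallel
/// lowerings. LoopTy selects the loop construct to emit; affine loops use
/// affine.load/affine.store while the scf variants use
/// memref.load/memref.store. Returns the emitted loop ops, or failure if an
/// induction variable cannot be traced back to its enclosing loop op.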
template <typename LoopTy>
static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  using LoadOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineLoadOp, memref::LoadOp>;
  using StoreOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineStoreOp, memref::StoreOp>;

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  SmallVector<Value> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
        assert(operandValuesToUse == linalgOp->getOperands() &&
               "expect operands are captured and not passed by loop argument");
        allIvs.append(ivs.begin(), ivs.end());
        emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
        return scf::ValueVector{};
      });
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return failure();
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = dyn_cast<BlockArgument>(iv);
    if (!ivVal)
      return failure();
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  // Replace all index operations in the loop body.
  replaceIndexOpsByInductionVariables(rewriter, linalgOp, loops);
  return loops;
}

namespace {
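/// Pattern that lowers any LinalgOp with buffer semantics to a loop nest of
/// type LoopType and erases the original op.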
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!isa<LinalgOp>(op) || !linalgOp.hasPureBufferSemantics()) {
      return rewriter.notifyMatchFailure(
          op, "expected linalg op with buffer semantics");
    }
    if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operands and a single AffineConstantExpr result
///   2. One operand and a single AffineDimExpr result
///   3. One operand and a single AffineSymbolExpr result
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
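/// For example, `affine.apply affine_map<() -> (42)>()` is replaced by
/// `arith.constant 42 : index`, and `affine.apply affine_map<(d0) -> (d0)>(%v)`
/// is replaced by `%v`.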
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(affine::AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto affineApplyOp = cast<affine::AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = dyn_cast<AffineConstantExpr>(expr)) {
        rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (isa<AffineDimExpr, AffineSymbolExpr>(expr)) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

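/// Builds the Linalg-to-loops rewrite pattern for LoopType together with a few
/// cleanup patterns (memref.dim/tensor.dim canonicalizations and trivial
/// affine.apply folding) and applies them greedily to `enclosingOp`.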
template <typename LoopType>
static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
  MLIRContext *context = enclosingOp->getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsGreedily(enclosingOp, std::move(patterns));
}

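/// Pass that lowers Linalg named and generic ops with buffer semantics to
/// affine.for loop nests.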
struct LowerToAffineLoops
    : public impl::ConvertLinalgToAffineLoopsPassBase<LowerToAffineLoops> {
  using impl::ConvertLinalgToAffineLoopsPassBase<
      LowerToAffineLoops>::ConvertLinalgToAffineLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<affine::AffineForOp>(getOperation());
  }
};

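/// Pass that lowers Linalg named and generic ops with buffer semantics to
/// scf.for loop nests.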
struct LowerToLoops : public impl::ConvertLinalgToLoopsPassBase<LowerToLoops> {
  using impl::ConvertLinalgToLoopsPassBase<
      LowerToLoops>::ConvertLinalgToLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
  }
};

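/// Pass that lowers Linalg named and generic ops with buffer semantics to loop
/// nests that use scf.parallel for the parallel dimensions.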
struct LowerToParallelLoops
    : public impl::ConvertLinalgToParallelLoopsPassBase<LowerToParallelLoops> {
  using impl::ConvertLinalgToParallelLoopsPassBase<
      LowerToParallelLoops>::ConvertLinalgToParallelLoopsPassBase;
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
  }
};

} // namespace

/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<affine::AffineForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(RewriterBase &rewriter,
                                                     LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(RewriterBase &rewriter,
                                      LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}
| 379 | |