//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// Pattern to decompose a GenericOp that has more than two statements
/// into one GenericOp with the first statement (i.e. peeled operation), and
/// a second GenericOp with the remaining statements (i.e. residual operations).
///
/// - The result of the first GenericOp has the same shape as the iteration
///   space of the GenericOp. The body of the op yields as many values as the
///   original op plus all the results of the peeled operation.
/// - The second GenericOp has as many operands as the original operation plus
///   all the results of the first GenericOp. It has the same number of yields
///   as the original op.
/// - If the result of the peeled operation was yielded by the original
///   GenericOp, the uses of the corresponding results will be replaced with
///   the result of the first GenericOp created.
///
/// Example
///
/// ```mlir
/// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///     outs(%init0, %init1 : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
///     %0 = <s0> %b0, %b1 : ...
///     %1 = <s1> %0, %b2 : ...
///     linalg.yield %0, %1 : ...
/// } -> (..., ...)
/// return %result#0, %result#1
/// ```
///
/// gets split into
///
/// ```mlir
/// %init = tensor.empty ...
/// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///     outs(%init0, %init1, %init : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///     %0 = <s0> %b0, %b1 : ...
///     linalg.yield %0, %..., %0 : ...
/// } -> (..., ..., ...)
/// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
///     outs(%init0, %init1 : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///     %1 = <s1> %b3, %b2 : ...
///     linalg.yield %..., %1 : ...
/// } -> (..., ...)
/// return %op0#0, %op1#1
/// ```
///
/// After canonicalization this is expected to be
///
/// ```mlir
/// %init = tensor.empty ...
/// %op0 = linalg.generic ... ins(%arg0, %arg1 : ...)
///     outs(%init : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ...):
///     %0 = <s0> %b0, %b1 : ...
///     linalg.yield %0 : ...
/// } -> ...
/// %op1 = linalg.generic ... ins(%arg2, %op0 : ...)
///     outs(%init1 : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ...):
///     %1 = <s1> %b1, %b0 : ...
///     linalg.yield %1 : ...
/// } -> ...
/// return %op0, %op1
/// ```
struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Helper method to create a generic op for the peeled scalar operation. The
  /// created op has an empty region.
  GenericOp createPeeledGenericOp(GenericOp genericOp,
                                  PatternRewriter &rewriter) const;

  /// Helper method to create a generic op for the residual scalar operation.
  /// The created op has the same region as the original op.
  GenericOp createResidualGenericOp(GenericOp genericOp,
                                    GenericOp peeledGenericOp,
                                    PatternRewriter &rewriter) const;
};
} // namespace

/// Helper method to compute the loop ranges (iteration domain) of a generic
/// op, by applying its shapes-to-loops map to the sizes of all its operands.
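/// For example (illustrative shapes): for a generic op whose operands all use
/// identity indexing maps over tensor<4x?xf32>, this returns {4, %dim}, where
/// %dim is the dynamic extent of dimension 1.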
static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b,
                                                       GenericOp op) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);
  Location loc = op.getLoc();
  auto allShapesSizes =
      cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
  AffineMap map = op.getShapesToLoopsMap();
  IRRewriter rewriter(b);
  return affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
                                                          allShapesSizes);
}

/// Helper method to permute the list of `values` based on the `map`.
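/// For example (illustrative values): with map = (d0, d1, d2) -> (d1, d2, d0)
/// and values = [a, b, c], result i of the map gives the destination position
/// of values[i], so this returns [c, a, b].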
static SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
                                               AffineMap map) {
  assert(map.isPermutation());
  SmallVector<OpFoldResult> permutedValues(values.size());
  for (const auto &position :
       llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
         return cast<AffineDimExpr>(expr).getPosition();
       })))
    permutedValues[position.value()] = values[position.index()];
  return permutedValues;
}

/// Get zero value for an element type.
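/// For example, an i32 element type yields `arith.constant 0 : i32`; a float
/// type yields the corresponding floating-point zero constant.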
static Value getZero(OpBuilder &b, Location loc, Type elementType) {
  assert(elementType.isIntOrIndexOrFloat() &&
         "expected scalar type while computing zero value");
  if (isa<IntegerType>(elementType))
    return b.create<arith::ConstantIntOp>(loc, 0, elementType);
  if (elementType.isIndex())
    return b.create<arith::ConstantIndexOp>(loc, 0);
  // Assume float.
  auto floatType = cast<FloatType>(elementType);
  return b.create<arith::ConstantFloatOp>(
      loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
}

GenericOp
DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
                                         PatternRewriter &rewriter) const {
  Block *body = genericOp.getBody();
  Operation *peeledScalarOperation = &(*body->begin());
  SmallVector<AffineMap> peeledGenericOpIndexingMaps =
      genericOp.getIndexingMapsArray();

  /// Compute the loop ranges for the operation. This is the shape of the
  /// result of the generic op for the peeled operation.
  Location loc = genericOp.getLoc();
  SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
  SmallVector<Value> newInitValues;
  SmallVector<Type> newResultTypes;

  // Add as many new results as the number of results of the peeled scalar op.
  for (auto scalarOpResult : peeledScalarOperation->getResults()) {
    // If the result is yielded by the original op, use the operand, indexing
    // map and result type that correspond to the yielded value.
    std::optional<unsigned> resultNumber;
    for (auto *user : scalarOpResult.getUsers()) {
      if (auto yieldOp = dyn_cast<YieldOp>(user)) {
        // Find the first use of the `scalarOpResult` in the yield op.
        for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
          if (yieldOperand.get() == scalarOpResult) {
            resultNumber = yieldOperand.getOperandNumber();
            break;
          }
        }
        assert(resultNumber && "unable to find use of a value in its user");
        break;
      }
    }
    if (resultNumber) {
      newInitValues.push_back(
          genericOp.getDpsInitOperand(*resultNumber)->get());
      OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber));
      newResultTypes.push_back(result.getType());
      peeledGenericOpIndexingMaps.push_back(
          genericOp.getIndexingMapMatchingResult(result));
      continue;
    }

    // Fall back path: use a `tensor.empty` as the init operand and an identity
    // indexing map.
    AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size());
    Value emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, domain, scalarOpResult.getType());
    newInitValues.push_back(emptyTensor);
    newResultTypes.push_back(emptyTensor.getType());
    peeledGenericOpIndexingMaps.push_back(indexingMap);
  }

  /// Create the peeled generic op with an empty body.
  SmallVector<Value> outsOperands = genericOp.getOutputs();
  outsOperands.append(newInitValues.begin(), newInitValues.end());
  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
  resultTypes.append(newResultTypes.begin(), newResultTypes.end());
  auto indexingMapAttr =
      rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
  return rewriter.create<GenericOp>(
      loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr,
      genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}

GenericOp
DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
                                           GenericOp peeledGenericOp,
                                           PatternRewriter &rewriter) const {
  /// Append all results from the peeledGenericOp as `ins` operands for the
  /// residual generic op.
  SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
  unsigned origNumResults = genericOp.getNumResults();
  unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
  SmallVector<Value> extraIns;
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
    extraIns.push_back(peeledGenericOp->getResult(resultNum));
  residualGenericOpOperands.append(extraIns);

  /// Add indexing maps for the newly added operands. Use the same maps
  /// as those used for the new results of the peeledGenericOp.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
        return genericOp.getMatchingIndexingMap(operand);
      }));
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
    OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum));
    indexingMaps.push_back(
        peeledGenericOp.getIndexingMapMatchingResult(result));
  }
  for (OpOperand &outOperand : genericOp.getDpsInitsMutable())
    indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand));

  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
  return rewriter.create<GenericOp>(
      genericOp->getLoc(), genericOp->getResultTypes(),
      residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
      genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}

LogicalResult
DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
                                   PatternRewriter &rewriter) const {
  /// For now, only match on operations where the iterator types are all
  /// parallel.
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "unhandled decomposition of operation "
                                       "with non-parallel iterator types");
  }
  // TODO: this could be generalized to handle `linalg.generic` with buffer
  // operands too, but that requires allocation for intermediates. Punt on this
  // for now.
  if (!genericOp.hasPureTensorSemantics()) {
    return rewriter.notifyMatchFailure(
        genericOp, "only operations with tensor semantics are handled");
  }

  if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) {
        return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation();
      })) {
    return rewriter.notifyMatchFailure(
        genericOp, "unhandled decomposition of generic op with out operand not "
                   "accessed using a permutation");
  }

  /// If the op has only a single statement (apart from the yield), do nothing.
  Block *body = genericOp.getBody();
  if (body->getOperations().size() <= 2) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "operation has fewer than 3 statements");
  }

  /// Check that the peeled statement has a scalar element type.
  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
                   [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
    return rewriter.notifyMatchFailure(
        &(*body->getOperations().begin()),
        "expected return type to be only int, index or float");
  }

  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
  GenericOp residualGenericOp =
      createResidualGenericOp(genericOp, peeledGenericOp, rewriter);

  /// Move the first statement of the original operation into the body of the
  /// generic op for the peeled operation.
  Block *peeledGenericOpBody = peeledGenericOp.getBody();
  Block *residualGenericOpBody = residualGenericOp.getBody();
  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
         "expected split generic ops to have empty region");
  peeledGenericOpBody->getOperations().splice(
      peeledGenericOpBody->begin(), body->getOperations(), body->begin());
  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
                                                body->getOperations());

  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
  auto *yieldOp = residualGenericOpBody->getTerminator();
  {
    // Yield all the results of the peeled scalar operation.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(peeledGenericOpBody);
    SmallVector<Value> yieldedVals;
    for (auto origYield : yieldOp->getOperands()) {
      if (origYield.getDefiningOp() == peeledScalarOperation) {
        yieldedVals.push_back(origYield);
      } else {
        // Do not materialize any new ops inside of the decomposed LinalgOp,
        // as that would trigger another application of the rewrite pattern
        // (infinite loop).
        OpBuilder::InsertionGuard g(rewriter);
        rewriter.setInsertionPoint(peeledGenericOp);
        yieldedVals.push_back(
            getZero(rewriter, genericOp.getLoc(), origYield.getType()));
      }
    }
    yieldedVals.append(llvm::to_vector(
        llvm::map_range(peeledScalarOperation->getResults(),
                        [](OpResult opr) -> Value { return opr; })));
    rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
  }

  /// In the split operations, replace block argument uses that refer to the
  /// original operation with the block arguments of the newly created
  /// operations.
  unsigned origNumInputs = genericOp.getNumDpsInputs();
  for (const auto &inputBlockArg :
       llvm::enumerate(genericOp.getBody()->getArguments())) {
    Value residualOpReplacementArg =
        residualGenericOpBody->getArgument(inputBlockArg.index());
    rewriter.replaceUsesWithIf(
        inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == residualGenericOpBody;
        });

    Value peeledOpReplacementArg =
        peeledGenericOpBody->getArgument(inputBlockArg.index());
    rewriter.replaceUsesWithIf(
        inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == peeledGenericOpBody;
        });
  }

  /// Before fixing up the residual operation, track which values are yielded.
  /// If any of those come from the peeled scalar operation, the uses of the
  /// corresponding results have to be remapped to the results of the generic
  /// op for the peeled operation.
  SmallVector<Value> replacements;
  for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
    OpResult opr = dyn_cast<OpResult>(yieldValue.value());
    if (!opr || opr.getOwner() != peeledScalarOperation)
      replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
    else
      replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
  }

  /// Update all uses of the peeled scalar operation results in the residual op
  /// to the newly added arguments.
  {
    SmallVector<Value> scalarReplacements;
    unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
    scalarReplacements.reserve(peeledScalarOpNumResults);
    for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
      scalarReplacements.push_back(
          residualGenericOpBody->getArgument(num + origNumInputs));
    bool allUsesReplaced = false;
    rewriter.replaceOpUsesWithinBlock(peeledScalarOperation, scalarReplacements,
                                      residualGenericOpBody, &allUsesReplaced);
    assert(!allUsesReplaced &&
           "peeled scalar operation is erased when it wasn't expected to be");
  }

  // Replace the original operation.
  rewriter.replaceOp(genericOp, replacements);
  return success();
}

void mlir::linalg::populateDecomposeLinalgOpsPattern(
    RewritePatternSet &patterns, bool removeDeadArgsAndResults) {
  patterns.insert<DecomposeLinalgOp>(patterns.getContext());
  // Add the patterns to clean up the dead operands and results.
  if (removeDeadArgsAndResults)
    populateEraseUnusedOperandsAndResultsPatterns(patterns);
}
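
// A minimal usage sketch (illustrative, not part of this file): a pass would
// typically register this pattern and run it with the greedy rewrite driver
// from "mlir/Transforms/GreedyPatternRewriteDriver.h". The `funcOp` name below
// is a placeholder.
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateDecomposeLinalgOpsPattern(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();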