1 | //===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
10 | |
11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
12 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
13 | #include <optional> |
14 | |
15 | using namespace mlir; |
16 | using namespace mlir::linalg; |
17 | |
18 | namespace { |
19 | |
20 | /// Pattern to decompose a GenericOp that has more than two statements |
21 | /// into one GenericOp with the first statement (i.e. peeled operation), and |
/// a second GenericOp with the remaining statements (i.e. residual operations).
///
/// - The result of the first GenericOp has the same shape as the iteration
25 | /// space of the GenericOp. The body of the op yields as many values as the |
26 | /// original op plus all the results of the peeled operation. |
27 | /// - The second GenericOp has as many operands as the original operation plus |
/// all the results of the first GenericOp. It has the same number of yields as
29 | /// the original op. |
30 | /// - If the result of the peeled operation was yielded by the original |
31 | /// GenericOp the uses of the corresponding results will be replaced with the |
32 | /// result of the first GenericOp created. |
33 | /// |
34 | /// Example |
35 | /// |
36 | /// ```mlir |
/// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///     outs(%init0, %init1 : ...) {
///   ^bb0(%b0: ..., %b1: ..., %b2: ..., %b3: ..., %b4: ...):
///     %0 = <s0> %b0, %b1 : ...
///     %1 = <s1> %0, %b2 : ...
///     linalg.yield %0, %1 : ...
/// } -> (..., ...)
/// return %result#0, %result#1
45 | /// ``` |
46 | /// |
47 | /// gets split into |
48 | /// |
49 | /// ```mlir |
/// %init = tensor.empty ...
/// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///     outs(%init0, %init1, %init : ...) {
///   ^bb0(%b0: ..., %b1: ..., %b2: ..., %b3: ..., %b4: ..., %b5: ...):
///     %0 = <s0> %b0, %b1 : ...
///     linalg.yield %0, %..., %0 : ...
/// } -> (..., ..., ...)
/// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
///     outs(%init0, %init1 : ...) {
///   ^bb0(%b0: ..., %b1: ..., %b2: ..., %b3: ..., %b4: ..., %b5: ...):
///     %1 = <s1> %b3, %b2 : ...
///     linalg.yield %..., %1 : ...
/// } -> (..., ...)
/// return %op0#0, %op1#1
64 | /// ``` |
65 | /// |
66 | /// After canonicalization this is expected to be |
67 | /// |
68 | /// ```mlir |
/// %init = tensor.empty ...
/// %op0 = linalg.generic ... ins(%arg0, %arg1 : ...)
///     outs(%init : ...) {
///   ^bb0(%b0: ..., %b1: ..., %b2: ...):
///     %0 = <s0> %b0, %b1 : ...
///     linalg.yield %0 : ...
/// } -> ...
/// %op1 = linalg.generic ... ins(%arg2, %op0 : ...)
///     outs(%init1 : ...) {
///   ^bb0(%b0: ..., %b1: ..., %b2: ...):
///     %1 = <s1> %b1, %b0 : ...
///     linalg.yield %1 : ...
/// } -> ...
/// return %op0, %op1
83 | /// ``` |
84 | struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> { |
85 | using OpRewritePattern<GenericOp>::OpRewritePattern; |
86 | |
87 | LogicalResult matchAndRewrite(GenericOp genericOp, |
88 | PatternRewriter &rewriter) const override; |
89 | |
90 | private: |
91 | /// Helper method to create a generic op for the peeled scalar operation. The |
92 | /// created op has an empty region. |
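  /// In the class-level example above, this corresponds to creating the
  /// empty-bodied `%op0:3`.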
93 | GenericOp createPeeledGenericOp(GenericOp genericOp, |
94 | PatternRewriter &rewriter) const; |
95 | |
  /// Helper method to create a generic op for the residual scalar operations.
  /// The created op has an empty region; the remaining statements of the
  /// original op are spliced into it by `matchAndRewrite`.
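  /// In the class-level example above, this corresponds to creating `%op1:2`,
  /// with the new results of `%op0` appended as `ins` operands.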
98 | GenericOp createResidualGenericOp(GenericOp genericOp, |
99 | GenericOp peeledGenericOp, |
100 | PatternRewriter &rewriter) const; |
101 | }; |
102 | } // namespace |
103 | |
104 | /// Helper method to compute the range of a generic op. |
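/// As an illustrative sketch (hypothetical shapes, not from a specific test):
/// for an op with indexing maps `(d0, d1) -> (d0, d1)` and `(d0, d1) -> (d1)`
/// over tensors of shape `10x20` and `20`, the flat operand-dim list is
/// `[10, 20, 20]` and the shapes-to-loops map folds it to the loop range
/// `[10, 20]`.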
105 | static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b, |
106 | GenericOp op) { |
107 | OpBuilder::InsertionGuard g(b); |
108 | b.setInsertionPoint(op); |
109 | Location loc = op.getLoc(); |
110 | auto allShapesSizes = |
111 | cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc); |
112 | AffineMap map = op.getShapesToLoopsMap(); |
113 | IRRewriter rewriter(b); |
  return affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
                                                          allShapesSizes);
116 | } |
117 | |
118 | /// Helper method to permute the list of `values` based on the `map`. |
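/// A worked example (hypothetical values): with
/// `map = affine_map<(d0, d1, d2) -> (d2, d0, d1)>` and `values = [a, b, c]`,
/// result i of `map` gives the destination slot of `values[i]`, so this
/// returns `[b, c, a]`.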
static SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
                                               AffineMap map) {
121 | assert(map.isPermutation()); |
122 | SmallVector<OpFoldResult> permutedValues(values.size()); |
  for (const auto &position :
       llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
         return cast<AffineDimExpr>(expr).getPosition();
       })))
127 | permutedValues[position.value()] = values[position.index()]; |
128 | return permutedValues; |
129 | } |
130 | |
131 | /// Get zero value for an element type. |
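/// For example, this is expected to produce `arith.constant 0 : i32` for an
/// `i32` element type and `arith.constant 0.000000e+00 : f32` for `f32`.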
132 | static Value getZero(OpBuilder &b, Location loc, Type elementType) { |
  assert(elementType.isIntOrIndexOrFloat() &&
         "expected scalar type while computing zero value");
  if (isa<IntegerType>(elementType))
    return b.create<arith::ConstantIntOp>(loc, 0, elementType);
  if (elementType.isIndex())
    return b.create<arith::ConstantIndexOp>(loc, 0);
  // Assume float.
  auto floatType = cast<FloatType>(elementType);
  return b.create<arith::ConstantFloatOp>(
      loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
143 | } |
144 | |
145 | GenericOp |
146 | DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp, |
147 | PatternRewriter &rewriter) const { |
148 | Block *body = genericOp.getBody(); |
149 | Operation *peeledScalarOperation = &(*body->begin()); |
150 | SmallVector<AffineMap> peeledGenericOpIndexingMaps = |
151 | genericOp.getIndexingMapsArray(); |
152 | |
153 | /// Compute the loop ranges for operation. This is the shape of the result of |
154 | /// the generic op for the peeled operation. |
155 | Location loc = genericOp.getLoc(); |
156 | SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp); |
157 | SmallVector<Value> newInitValues; |
158 | SmallVector<Type> newResultTypes; |
159 | |
160 | // Add as many new results as the number of results of the peeled scalar op. |
161 | for (auto scalarOpResult : peeledScalarOperation->getResults()) { |
    // If the result is yielded by the original op, use the operand, indexing
    // map and result type that correspond to the yielded value.
    std::optional<unsigned> resultNumber;
166 | for (auto *user : scalarOpResult.getUsers()) { |
167 | if (auto yieldOp = dyn_cast<YieldOp>(user)) { |
168 | // Find the first use of the `scalarOpResult` in the yield op. |
169 | for (OpOperand &yieldOperand : yieldOp->getOpOperands()) { |
170 | if (yieldOperand.get() == scalarOpResult) { |
171 | resultNumber = yieldOperand.getOperandNumber(); |
172 | break; |
173 | } |
174 | } |
        assert(resultNumber && "unable to find use of a value in its user");
176 | break; |
177 | } |
178 | } |
179 | if (resultNumber) { |
180 | newInitValues.push_back( |
181 | genericOp.getDpsInitOperand(*resultNumber)->get()); |
182 | OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber)); |
183 | newResultTypes.push_back(result.getType()); |
184 | peeledGenericOpIndexingMaps.push_back( |
185 | genericOp.getIndexingMapMatchingResult(result)); |
186 | continue; |
187 | } |
188 | |
    // Fall back path, use a new `tensor.empty` and identity indexing map.
190 | AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size()); |
191 | Value emptyTensor = |
192 | rewriter.create<tensor::EmptyOp>(loc, domain, scalarOpResult.getType()); |
193 | newInitValues.push_back(emptyTensor); |
194 | newResultTypes.push_back(emptyTensor.getType()); |
195 | peeledGenericOpIndexingMaps.push_back(indexingMap); |
196 | } |
197 | |
198 | /// Create the peeled generic op with an empty body. |
199 | SmallVector<Value> outsOperands = genericOp.getOutputs(); |
  outsOperands.append(newInitValues.begin(), newInitValues.end());
  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
  resultTypes.append(newResultTypes.begin(), newResultTypes.end());
203 | auto indexingMapAttr = |
204 | rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps); |
205 | return rewriter.create<GenericOp>( |
206 | loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr, |
207 | genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr, |
208 | [](OpBuilder, Location, ValueRange) {}); |
209 | } |
210 | |
211 | GenericOp |
212 | DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp, |
213 | GenericOp peeledGenericOp, |
214 | PatternRewriter &rewriter) const { |
215 | /// Append all results from the peeledGenericOps as `ins` operand for the |
216 | /// residual generic op. |
217 | SmallVector<Value> residualGenericOpOperands = genericOp.getInputs(); |
218 | unsigned origNumResults = genericOp.getNumResults(); |
219 | unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults(); |
  SmallVector<Value> extraIns;
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
    extraIns.push_back(peeledGenericOp->getResult(resultNum));
  residualGenericOpOperands.append(extraIns);
225 | |
226 | /// Add indexing maps for the newly added operands. Use the same map |
227 | /// as those used for the new results of the peeledGenericOp. |
228 | auto indexingMaps = llvm::to_vector( |
229 | llvm::map_range(genericOp.getDpsInputOperands(), [&](OpOperand *operand) { |
230 | return genericOp.getMatchingIndexingMap(operand); |
231 | })); |
232 | for (auto resultNum : |
233 | llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) { |
234 | OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum)); |
235 | indexingMaps.push_back( |
236 | peeledGenericOp.getIndexingMapMatchingResult(result)); |
237 | } |
238 | for (OpOperand &outOperand : genericOp.getDpsInitsMutable()) |
239 | indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand)); |
240 | |
  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
242 | return rewriter.create<GenericOp>( |
243 | genericOp->getLoc(), genericOp->getResultTypes(), |
244 | residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr, |
245 | genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr, |
246 | [](OpBuilder, Location, ValueRange) {}); |
247 | } |
248 | |
249 | LogicalResult |
250 | DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp, |
251 | PatternRewriter &rewriter) const { |
  /// For now, only match on operations where the iterator types are all
  /// parallel.
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "unhandled decomposition of operation "
                                       "with non-parallel iterator types");
257 | } |
258 | // TODO: this could be generalized to handle `linalg.generic` with buffer |
259 | // operands too but requires allocation for intermediates. Punt on this for |
260 | // now. |
261 | if (!genericOp.hasPureTensorSemantics()) { |
    return rewriter.notifyMatchFailure(
        genericOp, "only operations with tensor semantics are handled");
264 | } |
265 | |
266 | if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) { |
267 | return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation(); |
268 | })) { |
    return rewriter.notifyMatchFailure(
        genericOp, "unhandled decomposition of generic op with out operand not "
                   "accessed using a permutation");
272 | } |
273 | |
274 | /// If the op has only a single statement (apart from the yield), do nothing. |
275 | Block *body = genericOp.getBody(); |
276 | if (body->getOperations().size() <= 2) { |
    return rewriter.notifyMatchFailure(genericOp,
                                       "operation has fewer than 3 statements");
279 | } |
280 | |
281 | /// Check that the peeled statement has a scalar element type. |
  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
                   [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
    return rewriter.notifyMatchFailure(
        &(*body->getOperations().begin()),
        "expected return type to be only int, index or float");
287 | } |
288 | |
289 | GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter); |
290 | GenericOp residualGenericOp = |
291 | createResidualGenericOp(genericOp, peeledGenericOp, rewriter); |
292 | |
293 | /// Move the first statement of the original operation into the body of the |
294 | /// generic op for the peeled operation. |
295 | Block *peeledGenericOpBody = peeledGenericOp.getBody(); |
296 | Block *residualGenericOpBody = residualGenericOp.getBody(); |
297 | assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() && |
298 | "expected split generic ops to have empty region" ); |
299 | peeledGenericOpBody->getOperations().splice( |
300 | where: peeledGenericOpBody->begin(), L2&: body->getOperations(), first: body->begin()); |
301 | residualGenericOpBody->getOperations().splice(where: residualGenericOpBody->begin(), |
302 | L2&: body->getOperations()); |
303 | |
304 | Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin()); |
305 | auto *yieldOp = residualGenericOpBody->getTerminator(); |
306 | { |
307 | // Yield all the result of the peeled scalar operation. |
308 | OpBuilder::InsertionGuard g(rewriter); |
309 | rewriter.setInsertionPointToEnd(peeledGenericOpBody); |
310 | SmallVector<Value> yieldedVals; |
311 | for (auto origYield : yieldOp->getOperands()) { |
312 | if (origYield.getDefiningOp() == peeledScalarOperation) { |
313 | yieldedVals.push_back(origYield); |
314 | } else { |
315 | // Do not materialize any new ops inside of the decomposed LinalgOp, |
316 | // as that would trigger another application of the rewrite pattern |
317 | // (infinite loop). |
318 | OpBuilder::InsertionGuard g(rewriter); |
319 | rewriter.setInsertionPoint(peeledGenericOp); |
320 | yieldedVals.push_back( |
321 | getZero(rewriter, genericOp.getLoc(), origYield.getType())); |
322 | } |
323 | } |
    yieldedVals.append(llvm::to_vector(
        llvm::map_range(peeledScalarOperation->getResults(),
                        [](OpResult opr) -> Value { return opr; })));
327 | rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals); |
328 | } |
329 | |
330 | /// In the split operations, replace block arguments uses that refer to |
331 | /// original operation to the block arguments of the newly created operation. |
332 | unsigned origNumInputs = genericOp.getNumDpsInputs(); |
333 | for (const auto &inputBlockArg : |
334 | llvm::enumerate(genericOp.getBody()->getArguments())) { |
335 | Value residualOpReplacementArg = |
336 | residualGenericOpBody->getArgument(inputBlockArg.index()); |
337 | rewriter.replaceUsesWithIf( |
338 | inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) { |
339 | return use.getOwner()->getBlock() == residualGenericOpBody; |
340 | }); |
341 | |
342 | Value peeledOpReplacementArg = |
343 | peeledGenericOpBody->getArgument(inputBlockArg.index()); |
344 | rewriter.replaceUsesWithIf( |
345 | inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) { |
346 | return use.getOwner()->getBlock() == peeledGenericOpBody; |
347 | }); |
348 | } |
349 | |
350 | /// Before fixing up the residual operation, track what values are yielded. If |
351 | /// any of those are from the peeled scalar operation, the uses of the |
352 | /// corresponding result have to be remapped to result of the generic op for |
353 | /// the peeled operation. |
354 | SmallVector<Value> replacements; |
355 | for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) { |
356 | OpResult opr = dyn_cast<OpResult>(yieldValue.value()); |
357 | if (!opr || opr.getOwner() != peeledScalarOperation) |
358 | replacements.push_back(residualGenericOp.getResult(yieldValue.index())); |
359 | else |
360 | replacements.push_back(peeledGenericOp->getResult(yieldValue.index())); |
361 | } |
362 | |
363 | /// Update all uses of the peeled scalar operation results in the residual op |
364 | /// to the newly added arguments. |
365 | { |
366 | SmallVector<Value> scalarReplacements; |
367 | unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults(); |
    scalarReplacements.reserve(peeledScalarOpNumResults);
369 | for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults)) |
370 | scalarReplacements.push_back( |
371 | residualGenericOpBody->getArgument(num + origNumInputs)); |
372 | bool allUsesReplaced = false; |
    rewriter.replaceOpUsesWithinBlock(peeledScalarOperation, scalarReplacements,
                                      residualGenericOpBody, &allUsesReplaced);
    assert(!allUsesReplaced &&
           "peeled scalar operation is erased when it wasn't expected to be");
377 | } |
378 | |
379 | // Replace the original operation |
380 | rewriter.replaceOp(genericOp, replacements); |
381 | return success(); |
382 | } |
383 | |
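// Typical usage, as a minimal sketch (the pass body and `funcOp` below are
// illustrative assumptions, not part of this file): collect the pattern and
// run it to a fixed point with the greedy pattern rewrite driver.
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateDecomposeLinalgOpsPattern(
//       patterns, /*removeDeadArgsAndResults=*/true);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();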
384 | void mlir::linalg::populateDecomposeLinalgOpsPattern( |
385 | RewritePatternSet &patterns, bool removeDeadArgsAndResults) { |
  patterns.insert<DecomposeLinalgOp>(patterns.getContext());
387 | // Add the patterns to clean up the dead operands and results. |
388 | if (removeDeadArgsAndResults) |
389 | populateEraseUnusedOperandsAndResultsPatterns(patterns); |
390 | } |
391 | |