//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// Pattern to decompose a GenericOp that has more than two statements
/// into one GenericOp with the first statement (i.e. the peeled operation),
/// and a second GenericOp with the remaining statements (i.e. the residual
/// operations).
///
/// - The result of the first GenericOp has the same shape as the iteration
///   space of the GenericOp. The body of the op yields as many values as the
///   original op plus all the results of the peeled operation.
/// - The second GenericOp has as many operands as the original operation plus
///   all the results of the first GenericOp. It has the same number of yields
///   as the original op.
/// - If a result of the peeled operation was yielded by the original
///   GenericOp, the uses of the corresponding result are replaced with the
///   matching result of the first GenericOp created.
///
/// Example
///
/// ```mlir
/// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///     outs(%init0, %init1 : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
///     %0 = <s0> %b0, %b1 : ...
///     %1 = <s1> %0, %b2 : ...
///     linalg.yield %0, %1 : ...
/// } -> (..., ...)
/// return %result#0, %result#1
/// ```
///
/// gets split into
///
/// ```mlir
/// %init = tensor.empty ...
/// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///     outs(%init0, %init1, %init : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///     %0 = <s0> %b0, %b1 : ...
///     linalg.yield %0, %..., %0 : ...
/// } -> (..., ..., ...)
/// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
///     outs(%init0, %init1 : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///     %1 = <s1> %b3, %b2 : ...
///     linalg.yield %..., %1 : ...
/// } -> (..., ...)
/// return %op0#0, %op1#1
/// ```
///
/// After canonicalization this is expected to be
///
/// ```mlir
/// %init = tensor.empty ...
/// %op0 = linalg.generic ... ins(%arg0, %arg1 : ...)
///     outs(%init : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ...):
///     %0 = <s0> %b0, %b1 : ...
///     linalg.yield %0 : ...
/// } -> ...
/// %op1 = linalg.generic ... ins(%arg2, %op0 : ...)
///     outs(%init1 : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ...):
///     %1 = <s1> %b1, %b0 : ...
///     linalg.yield %1 : ...
/// } -> ...
/// return %op0, %op1
/// ```
struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Helper method to create a generic op for the peeled scalar operation. The
  /// created op has an empty region.
  GenericOp createPeeledGenericOp(GenericOp genericOp,
                                  PatternRewriter &rewriter) const;

  /// Helper method to create a generic op for the residual scalar operations.
  /// The created op has the same region as the original op.
  GenericOp createResidualGenericOp(GenericOp genericOp,
                                    GenericOp peeledGenericOp,
                                    PatternRewriter &rewriter) const;
};
} // namespace

/// Helper method to compute the loop range of a generic op.
static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b,
                                                       GenericOp op) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);
  Location loc = op.getLoc();
  auto allShapesSizes =
      cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
  AffineMap map = op.getShapesToLoopsMap();
  IRRewriter rewriter(b);
  return affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
                                                          allShapesSizes);
}
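
// Illustrative example (hypothetical shapes): for a generic op whose operands
// span a `?x42` iteration space through identity indexing maps, the flattened
// operand dims include {%dim, c42} and `getShapesToLoopsMap` composes to the
// identity, so the returned range is {%dim, 42} as `OpFoldResult`s: a `Value`
// for the dynamic extent and an `Attribute` for the static one.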

/// Helper method to permute the list of `values` based on the `map`.
static SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
                                               AffineMap map) {
  assert(map.isPermutation());
  SmallVector<OpFoldResult> permutedValues(values.size());
  for (const auto &position :
       llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
         return cast<AffineDimExpr>(expr).getPosition();
       })))
    permutedValues[position.value()] = values[position.index()];
  return permutedValues;
}
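
// Worked example (hypothetical values): for map = (d0, d1, d2) -> (d1, d2, d0)
// and values = [a, b, c], result 0 of the map is d1, so permutedValues[1] = a;
// likewise permutedValues[2] = b and permutedValues[0] = c, giving [c, a, b].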

/// Get zero value for an element type.
static Value getZero(OpBuilder &b, Location loc, Type elementType) {
  assert(elementType.isIntOrIndexOrFloat() &&
         "expected scalar type while computing zero value");
  if (isa<IntegerType>(elementType))
    return b.create<arith::ConstantIntOp>(loc, 0, elementType);
  if (elementType.isIndex())
    return b.create<arith::ConstantIndexOp>(loc, 0);
  // Assume float.
  auto floatType = cast<FloatType>(elementType);
  return b.create<arith::ConstantFloatOp>(
      loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
}
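
// `getZero` is used below to synthesize placeholder yields in the peeled
// generic op for yielded values that are not defined by the peeled scalar
// operation. The corresponding results are dead by construction and are
// expected to be removed by the erase-unused-operands/results patterns
// registered in `populateDecomposeLinalgOpsPattern`.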

GenericOp
DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
                                         PatternRewriter &rewriter) const {
  Block *body = genericOp.getBody();
  Operation *peeledScalarOperation = &(*body->begin());
  SmallVector<AffineMap> peeledGenericOpIndexingMaps =
      genericOp.getIndexingMapsArray();

  // Compute the loop ranges for the operation. This is the shape of the
  // result of the generic op for the peeled operation.
  Location loc = genericOp.getLoc();
  SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
  SmallVector<Value> newInitValues;
  SmallVector<Type> newResultTypes;

  // Add as many new results as the number of results of the peeled scalar op.
  for (auto scalarOpResult : peeledScalarOperation->getResults()) {
    // If the result is yielded by the original op, use the operand, indexing
    // map and result type that correspond to the yielded value.
    std::optional<unsigned> resultNumber;
    for (auto *user : scalarOpResult.getUsers()) {
      if (auto yieldOp = dyn_cast<YieldOp>(user)) {
        // Find the first use of the `scalarOpResult` in the yield op.
        for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
          if (yieldOperand.get() == scalarOpResult) {
            resultNumber = yieldOperand.getOperandNumber();
            break;
          }
        }
        assert(resultNumber && "unable to find use of a value in its user");
        break;
      }
    }
    if (resultNumber) {
      newInitValues.push_back(
          genericOp.getDpsInitOperand(*resultNumber)->get());
      OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber));
      newResultTypes.push_back(result.getType());
      peeledGenericOpIndexingMaps.push_back(
          genericOp.getIndexingMapMatchingResult(result));
      continue;
    }

    // Fall back path: use a `tensor.empty` as init and an identity indexing
    // map.
    AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size());
    Value emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, domain, scalarOpResult.getType());
    newInitValues.push_back(emptyTensor);
    newResultTypes.push_back(emptyTensor.getType());
    peeledGenericOpIndexingMaps.push_back(indexingMap);
  }

  // Create the peeled generic op with an empty body.
  SmallVector<Value> outsOperands = genericOp.getOutputs();
  outsOperands.append(newInitValues.begin(), newInitValues.end());
  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
  resultTypes.append(newResultTypes.begin(), newResultTypes.end());
  auto indexingMapAttr =
      rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
  return rewriter.create<GenericOp>(
      loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr,
      genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}
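
// Note: the no-op body builder passed to `create<GenericOp>` above
// materializes the entry block (with the expected block arguments) but leaves
// it without operations; the body is populated later in `matchAndRewrite` by
// splicing statements out of the original op.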

GenericOp
DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
                                           GenericOp peeledGenericOp,
                                           PatternRewriter &rewriter) const {
  // Append the newly added results of the peeledGenericOp (those beyond the
  // original op's results) as `ins` operands of the residual generic op.
  SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
  unsigned origNumResults = genericOp.getNumResults();
  unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
  SmallVector<Value> extraIns;
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
    extraIns.push_back(peeledGenericOp->getResult(resultNum));
  residualGenericOpOperands.append(extraIns);

  // Add indexing maps for the newly added operands. Use the same maps
  // as those used for the new results of the peeledGenericOp.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
        return genericOp.getMatchingIndexingMap(operand);
      }));
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
    OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum));
    indexingMaps.push_back(
        peeledGenericOp.getIndexingMapMatchingResult(result));
  }
  for (OpOperand &outOperand : genericOp.getDpsInitsMutable())
    indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand));

  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
  return rewriter.create<GenericOp>(
      genericOp->getLoc(), genericOp->getResultTypes(),
      residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
      genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}
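
// Note on block-argument layout (implied by the operand order above): the
// residual op's entry block sees the original inputs first, then the extra
// results of the peeled generic op, then the outs. `matchAndRewrite` relies on
// this when it remaps the peeled scalar results to block arguments starting at
// `origNumInputs`.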

LogicalResult
DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
                                   PatternRewriter &rewriter) const {
  // For now, only match on operations where the iterator types are all
  // parallel.
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "unhandled decomposition of operation "
                                       "with non-parallel iterator types");
  }
  // TODO: this could be generalized to handle `linalg.generic` with buffer
  // operands too, but that requires allocation for intermediates. Punt on
  // this for now.
  if (!genericOp.hasPureTensorSemantics()) {
    return rewriter.notifyMatchFailure(
        genericOp, "only operations with tensor semantics are handled");
  }

  if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) {
        return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation();
      })) {
    return rewriter.notifyMatchFailure(
        genericOp, "unhandled decomposition of generic op with out operand not "
                   "accessed using a permutation");
  }

  // If the op has only a single statement (apart from the yield), do nothing.
  Block *body = genericOp.getBody();
  if (body->getOperations().size() <= 2) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "operation has fewer than 3 statements");
  }

  // Check that the peeled statement has a scalar element type.
  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
                   [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
    return rewriter.notifyMatchFailure(
        &(*body->getOperations().begin()),
        "expected return type to be only int, index or float");
  }

  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
  GenericOp residualGenericOp =
      createResidualGenericOp(genericOp, peeledGenericOp, rewriter);

  // Move the first statement of the original operation into the body of the
  // generic op for the peeled operation.
  Block *peeledGenericOpBody = peeledGenericOp.getBody();
  Block *residualGenericOpBody = residualGenericOp.getBody();
  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
         "expected split generic ops to have empty region");
  peeledGenericOpBody->getOperations().splice(
      peeledGenericOpBody->begin(), body->getOperations(), body->begin());
  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
                                                body->getOperations());

  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
  auto *yieldOp = residualGenericOpBody->getTerminator();
  {
    // Yield all the results of the peeled scalar operation.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(peeledGenericOpBody);
    SmallVector<Value> yieldedVals;
    for (auto origYield : yieldOp->getOperands()) {
      if (origYield.getDefiningOp() == peeledScalarOperation) {
        yieldedVals.push_back(origYield);
      } else {
        // Do not materialize any new ops inside of the decomposed LinalgOp,
        // as that would trigger another application of the rewrite pattern
        // (infinite loop).
        OpBuilder::InsertionGuard g(rewriter);
        rewriter.setInsertionPoint(peeledGenericOp);
        yieldedVals.push_back(
            getZero(rewriter, genericOp.getLoc(), origYield.getType()));
      }
    }
    yieldedVals.append(llvm::to_vector(
        llvm::map_range(peeledScalarOperation->getResults(),
                        [](OpResult opr) -> Value { return opr; })));
    rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
  }

  // In the split operations, replace uses of the original operation's block
  // arguments with the block arguments of the newly created operations.
  unsigned origNumInputs = genericOp.getNumDpsInputs();
  for (const auto &inputBlockArg :
       llvm::enumerate(genericOp.getBody()->getArguments())) {
    Value residualOpReplacementArg =
        residualGenericOpBody->getArgument(inputBlockArg.index());
    rewriter.replaceUsesWithIf(
        inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == residualGenericOpBody;
        });

    Value peeledOpReplacementArg =
        peeledGenericOpBody->getArgument(inputBlockArg.index());
    rewriter.replaceUsesWithIf(
        inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == peeledGenericOpBody;
        });
  }

  // Before fixing up the residual operation, track which values are yielded.
  // If any of those come from the peeled scalar operation, the uses of the
  // corresponding result have to be remapped to the result of the generic op
  // for the peeled operation.
  SmallVector<Value> replacements;
  for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
    OpResult opr = dyn_cast<OpResult>(yieldValue.value());
    if (!opr || opr.getOwner() != peeledScalarOperation)
      replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
    else
      replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
  }

  // Update all uses of the peeled scalar operation's results within the
  // residual op to use the newly added block arguments.
  {
    SmallVector<Value> scalarReplacements;
    unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
    scalarReplacements.reserve(peeledScalarOpNumResults);
    for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
      scalarReplacements.push_back(
          residualGenericOpBody->getArgument(num + origNumInputs));
    bool allUsesReplaced = false;
    rewriter.replaceOpUsesWithinBlock(peeledScalarOperation, scalarReplacements,
                                      residualGenericOpBody, &allUsesReplaced);
    assert(!allUsesReplaced &&
           "peeled scalar operation is erased when it wasn't expected to be");
  }

  // Replace the original operation.
  rewriter.replaceOp(genericOp, replacements);
  return success();
}

void mlir::linalg::populateDecomposeLinalgOpsPattern(
    RewritePatternSet &patterns, bool removeDeadArgsAndResults) {
  patterns.insert<DecomposeLinalgOp>(patterns.getContext());
  // Add the patterns to clean up the dead operands and results.
  if (removeDeadArgsAndResults)
    populateEraseUnusedOperandsAndResultsPatterns(patterns);
}
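
// Example usage (an illustrative sketch, not part of this file: assumes an
// enclosing pass and the greedy driver from
// mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(&getContext());
//   linalg::populateDecomposeLinalgOpsPattern(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();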