1 | //===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements patterns that transform linalg.<op> +
// tensor.extract_slice into tensor.extract_slice + linalg.<op>, which reduces
// the amount of computation performed by the linalg op.
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
16 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
17 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
18 | #include "mlir/Dialect/Linalg/Passes.h" |
19 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
20 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
21 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
22 | |
23 | using namespace mlir; |
24 | using namespace mlir::linalg; |
25 | |
26 | namespace { |
27 | /// Bubble up extract_slice above Linalg operation. |
28 | /// |
29 | /// A sequence of operations |
30 | /// |
31 | /// ```mlir |
32 | /// %0 = linalg.<op> ... arg0, arg1, ... |
33 | /// %1 = tensor.extract_slice %0 ... |
34 | /// ``` |
35 | /// |
36 | /// can be replaced with |
37 | /// |
38 | /// ```mlir |
39 | /// %0 = tensor.extract_slice %arg0 |
40 | /// %1 = tensor.extract_slice %arg1 |
41 | /// %2 = linalg.<op> ... %0, %1, ... |
42 | /// ``` |
43 | /// |
44 | /// This results in the reduce computation of the linalg operation. |
45 | /// |
46 | struct |
47 | : OpRewritePattern<tensor::ExtractSliceOp> { |
48 | using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern; |
49 | |
50 | LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, |
51 | PatternRewriter &rewriter) const final { |
52 | Value source = sliceOp.getSource(); |
53 | auto linalgOp = source.getDefiningOp<LinalgOp>(); |
54 | if (!linalgOp) { |
55 | return rewriter.notifyMatchFailure(sliceOp, |
56 | "expected source to be linalg op" ); |
57 | } |
58 | |
59 | // TODO: we might relax this if we want heuristics to detect that all uses |
60 | // are small portion of the output. |
61 | if (!linalgOp->hasOneUse()) { |
62 | return rewriter.notifyMatchFailure(sliceOp, |
63 | "expected single use of linalg op" ); |
64 | } |
65 | |
66 | if (linalgOp.getNumDpsInits() != 1) { |
67 | return rewriter.notifyMatchFailure(sliceOp, |
68 | "expected single output of linalg op" ); |
69 | } |
70 | |
71 | if (!linalgOp.hasPureTensorSemantics()) { |
72 | return rewriter.notifyMatchFailure(sliceOp, |
73 | "expected tensor of linalg op" ); |
74 | } |
75 | |
76 | if (!sliceOp.hasUnitStride()) |
77 | return rewriter.notifyMatchFailure(sliceOp, "expected unit stride" ); |
78 | |
79 | if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) { |
80 | return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction" ); |
81 | } |
82 | |
83 | OpOperand *outOperand = linalgOp.getDpsInitOperand(0); |
84 | AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand); |
85 | if (!indexingMap.isProjectedPermutation()) { |
86 | return rewriter.notifyMatchFailure( |
87 | sliceOp, "expected a projected permutation for output" ); |
88 | } |
89 | |
90 | auto linalgLoc = linalgOp.getLoc(); |
91 | SmallVector<OpFoldResult> allShapeSizes = |
92 | linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc); |
93 | AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap(); |
94 | if (!shapeSizesToLoopsMap) { |
95 | return rewriter.notifyMatchFailure( |
96 | linalgOp, "failed to get loops map from shape sizes" ); |
97 | } |
98 | SmallVector<OpFoldResult> sizeBounds = |
99 | affine::makeComposedFoldedMultiResultAffineApply( |
100 | rewriter, linalgLoc, shapeSizesToLoopsMap, allShapeSizes); |
101 | |
102 | // The offsets and sizes from the slice operation only give you the tile |
103 | // size of the output. Use that compute the tile sizes and offsets of the |
104 | // loops. For loops not used to access the output, set the tile sizes to |
105 | // loop bounds and set the offset to 0. |
106 | SmallVector<OpFoldResult> tileOffsets(sizeBounds.size(), |
107 | rewriter.getIndexAttr(0)); |
108 | SmallVector<OpFoldResult> tileSizes = sizeBounds; |
109 | for (auto const &result : enumerate(indexingMap.getResults())) { |
110 | unsigned position = cast<AffineDimExpr>(result.value()).getPosition(); |
111 | tileOffsets[position] = sliceOp.getMixedOffsets()[result.index()]; |
112 | tileSizes[position] = sliceOp.getMixedSizes()[result.index()]; |
113 | } |
114 | |
115 | SmallVector<Value> valuesToTile = linalgOp->getOperands(); |
116 | SmallVector<Value> tiledOperands = |
117 | makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile, |
118 | tileOffsets, tileSizes, sizeBounds, |
119 | /*omitPartialTileCheck=*/true); |
120 | |
121 | SmallVector<Type, 4> resultTensorTypes; |
122 | for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) |
123 | resultTensorTypes.push_back( |
124 | tiledOperands[opOperand.getOperandNumber()].getType()); |
125 | |
126 | Operation *newOp = |
127 | clone(rewriter, linalgOp, resultTensorTypes, tiledOperands); |
128 | rewriter.replaceOp(sliceOp, newOp->getResults()); |
129 | return success(); |
130 | } |
131 | }; |
132 | } // namespace |
133 | |
134 | void mlir::linalg::( |
135 | RewritePatternSet &patterns) { |
136 | auto *context = patterns.getContext(); |
137 | patterns.add<BubbleUpExtractSliceOpPattern>(arg&: context); |
138 | } |
139 | |