//===- DecomposeGenericByUnfoldingPermutation.cpp -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include <map>
#include <optional>
#include <utility>

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// This pattern decomposes the input operand(s) of a linalg.generic that has
/// a `transpose`, `broadcast`, or a mixture of the two, into an explicit
/// transpose and broadcast. Having them folded into the linalg.generic is a
/// good optimization, but sometimes we may want to unwrap, i.e. `unfold`,
/// them into explicit transpose and broadcast ops. This rewrite pattern does
/// that for each input operand. It is useful, for instance, when trying to
/// recognize named ops.
///
/// The transpose, broadcast, or mixture of both, is expressed in the affine
/// map of the operand. Technically, such a map is a `projected permutation`.
///
/// Example:
///
/// ```mlir
/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
/// #identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
/// ...
/// %res = linalg.generic
///     {indexing_maps = [#projection, #identity, #identity],
///      iterator_types = ["parallel", "parallel", "parallel",
///                        "parallel", "parallel"]}
///     ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
///     outs(%z : tensor<5x9x7x8x10xf32>) {
///   ^bb0(%in: f32, %in_1: f32, %out: f32):
///     %div = arith.divf %in, %in_1 : f32
///     linalg.yield %div : f32
/// } -> tensor<5x9x7x8x10xf32>
/// ```
///
/// In the above IR, the map of operand `%x` is a projected permutation. It
/// can be unfolded as:
///
/// ```mlir
/// ...
/// %x_trans = linalg.transpose
///     ins(%x : tensor<7x8x9xf32>)
///     outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
/// ...
/// %x_trans_bc = linalg.broadcast
///     ins(%x_trans : tensor<9x7x8xf32>)
///     outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
/// %2 = linalg.div
///     ins(%x_trans_bc, %y :
///         tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
///     outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
/// ```
///
/// Note that the linalg.generic has additionally been 'specialized' into
/// linalg.div; that specialization is a separate step, not performed by this
/// pattern.
///
/// When unfolding, it is more efficient to transpose first and then
/// broadcast. However, if the transpose is done first, the permutation map
/// needs to be expressed in terms of the reduced dimensions, as the
/// broadcast has not happened yet. Also, the broadcast dimensions of a
/// linalg.generic come from the other operands (those not broadcast along
/// that particular dimension). We work this out by computing the
/// convex-polyhedron shape of the linalg.generic iteration space from the
/// shapes of all the operands, both inputs and outputs.
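///
/// For the example above, the iteration space is 5x9x7x8x10: `%y` and `%z`
/// span all five dimensions, while `%x` contributes only d1, d2 and d3.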
///
struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;
};

/// For the given `map`, determine which dimensions are transposed and which
/// dimensions are broadcast.
/// Returns `(transpose-permutation, broadcast-dimensions)`; each component
/// is empty if not needed.
///
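/// For example, given five iteration-space dims and
/// map = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>, this returns
/// transpose-permutation [2, 0, 1] and broadcast-dimensions [0, 4].
///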
std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
computeTransposeBroadcast(AffineMap &map) {
  assert(map.isProjectedPermutation(false) && "not a projection");

  // As the map is a projection, the transpose concerns only the (possibly
  // smaller) set of dimensions appearing in the results; the rest are
  // broadcast.
  int64_t minorSize = map.getNumResults();

  SmallVector<int64_t> minorResult;
  for (int64_t i = 0; i < minorSize; ++i) {
    auto expr = cast<AffineDimExpr>(map.getResults()[i]);
    minorResult.push_back(expr.getPosition());
  }

  // If the dims are not monotonically increasing then a transpose is present.
  SmallVector<int64_t> sortedResMap(minorResult);
  llvm::sort(sortedResMap);
  bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
                                  sortedResMap.begin(), sortedResMap.end());

  // Walk the sorted map result to determine which dimensions are broadcast.
  SmallVector<int64_t> broadcast;
  for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
    if (j < minorSize && sortedResMap[j] == i) {
      j++;
      continue;
    }
    broadcast.push_back(i);
  }

  SmallVector<int64_t> permutation;
  if (hasTranspose) {
    // Consider an operand `x : tensor<7x8x9>` of a genericOp that has
    // affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`.
    // `x`'s access is both transposed and broadcast. But when specifying
    // the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
    // specified as `affine_map<(d0, d1, d2) -> (d1, d2, d0)>` instead of
    // referring to d3, d4. Therefore, re-base the transpose dimensions so
    // that they start from d0.
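    //
    // For `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>` this computes:
    //   minorResult    = [2, 3, 1],  sortedResMap = [1, 2, 3]
    //   minorMap       = {1: 0, 2: 1, 3: 2}
    //   remappedResult = [1, 2, 0]
    //   permutation    = [2, 0, 1]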
    permutation.resize(minorSize);
    std::map<int64_t, int64_t> minorMap;
    for (int64_t i = 0; i < minorSize; ++i)
      minorMap.insert({sortedResMap[i], i});

    // Re-map the dimensions.
    SmallVector<int64_t> remappedResult(minorSize);
    for (int64_t i = 0; i < minorSize; ++i)
      remappedResult[i] = minorMap[minorResult[i]];

    // Calculate the permutation for the transpose.
    for (int64_t i = 0; i < minorSize; ++i) {
      permutation[remappedResult[i]] = i;
    }
  }
  return {permutation, broadcast};
}

LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
    GenericOp op, PatternRewriter &rewriter) const {
  if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
      op.isSingleYieldOp() || !op.isAllParallelLoops())
    return failure();

  // If the map of an operand is not a `projected permutation` then it cannot
  // be decomposed into a mere transpose and broadcast. The requirement that
  // all maps be `projected permutation` may be over-restrictive, but since
  // we also need to determine the shape of the iteration space, reject the
  // op if any map violates that assumption.
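  // For example, affine_map<(d0, d1) -> (d0 + d1)> is not a projected
  // permutation, so an op carrying such a map is rejected here.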
  for (auto &opOperand : op->getOpOperands()) {
    auto map = op.getMatchingIndexingMap(&opOperand);
    if (!map.isProjectedPermutation(false))
      return failure();
  }

  // Decomposing a linalg.generic involves creating `tensor.empty` ops, which
  // can have dynamic shapes, but then we would have to work out which
  // operand can supply that runtime value (via tensor.dim). Leaving that as
  // a future TODO.
  if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
        auto opType = cast<RankedTensorType>(oper.get().getType());
        return ShapedType::isDynamicShape(opType.getShape());
      }))
    return failure();

  // The shape of the iteration space, e.g. [5, 9, 7, 8, 10] in the example
  // above.
  auto outputShape = op.getStaticLoopRanges();

  auto loc = op.getLoc();
  bool isChanged = false;
  SmallVector<Value> newInitValues = op.getDpsInputs();
  SmallVector<AffineMap> newMap = op.getIndexingMapsArray();

  // Walk over each input operand and unfold it if its affine map is a
  // transpose, a broadcast, or a mix of the two.
  for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
    auto &map = newMap[i];
    auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
    auto elType = inputRTType.getElementType();

    // Nothing to do if the map is already an identity.
    if (map.isIdentity())
      continue;

    auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);

    // Does it need a transpose?
    if (!permutation.empty()) {
      // linalg.transpose permutes the dimensions of the input using the
      // rule: dim(result, i) = dim(input, permutation[i]).
      SmallVector<int64_t> transposedShape(map.getNumResults());
      for (int64_t i = 0; i < map.getNumResults(); ++i)
        transposedShape[i] = inputRTType.getShape()[permutation[i]];
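      // For `%x` in the example above: permutation = [2, 0, 1], so an input
      // of shape [7, 8, 9] yields transposedShape = [9, 7, 8].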

      Value emptyTensor =
          rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);

      auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
                                                      emptyTensor, permutation);
      newInitValues[i] = transposeOp->getResult(0);
      isChanged = true;
    }

    // Does it require a broadcast?
    if (!broadcastedDims.empty()) {
      Value emptyTensor =
          rewriter.create<tensor::EmptyOp>(loc, outputShape, elType);

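      // For `%x_trans` in the example above: broadcastedDims = [0, 4],
      // expanding tensor<9x7x8xf32> into tensor<5x9x7x8x10xf32>.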
      auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
          loc, newInitValues[i], emptyTensor, broadcastedDims);

      newInitValues[i] = broadcastOp->getResult(0);
      isChanged = true;
    }
    // The operand now has identity access; update its indexing map.
    newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
  }

  if (!isChanged)
    return failure();

  SmallVector<Value> operands = op->getOperands();
  ValueRange operandsRef(operands);

  auto newOp = rewriter.create<linalg::GenericOp>(
      /*location=*/op.getLoc(),
      /*resultTensorTypes=*/op->getResultTypes(),
      /*inputs=*/newInitValues,
      /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
      /*indexingMaps=*/newMap,
      /*iteratorTypes=*/op.getIteratorTypesArray());
  newOp.getRegion().takeBody(op->getRegion(0));
  rewriter.replaceOp(op, newOp->getResults());
  return success();
}

} // namespace

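// A minimal sketch of driving this pattern (assumptions: a pass providing
// the usual `getContext()`, `getOperation()` and `signalPassFailure()`
// hooks, plus the greedy driver from
// mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(&getContext());
//   linalg::populateDecomposeProjectedPermutationPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();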
void mlir::linalg::populateDecomposeProjectedPermutationPatterns(
    RewritePatternSet &patterns) {
  patterns.insert<DecomposeProjectedPermutation>(patterns.getContext());
}