//===- DecomposeGenericByUnfoldingPermutation.cpp -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include <map>
#include <utility>

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// This pattern decomposes the input operand(s) of a linalg.generic that
/// have a `transpose`, a `broadcast`, or a mixture of the two folded in,
/// into explicit transpose and broadcast ops. Having them folded into the
/// linalg.generic is a good optimization, but sometimes we may want to
/// unwrap, i.e. `unfold`, them into explicit transpose and broadcast ops.
/// This rewrite pattern does that for each input operand. It is useful,
/// for instance, when trying to recognize named ops.
///
/// The transpose, broadcast, or mixture of both, is expressed in the
/// affine map of the operand. Technically, such a map is a `projected
/// permutation`.
///
/// Example
///
/// ```mlir
///
/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
/// #identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
/// ...
/// %res = linalg.generic
///    { indexing_maps = [#projection, #identity, #identity],
///      iterator_types = ["parallel", "parallel", "parallel",
///                        "parallel", "parallel"]}
///    ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
///    outs(%z : tensor<5x9x7x8x10xf32>) {
///      ^bb0(%in: f32, %in_1: f32, %out: f32):
///        %div = arith.divf %in, %in_1 : f32
///        linalg.yield %div : f32
///    } -> tensor<5x9x7x8x10xf32>
/// ```
///
/// In the IR above, the map of operand `%x` is a projected permutation.
/// It can be unfolded as:
///
/// ```mlir
/// ...
/// %x_trans = linalg.transpose
///                  ins(%x : tensor<7x8x9xf32>)
///                  outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
/// ...
/// %x_trans_bc = linalg.broadcast
///                  ins(%x_trans : tensor<9x7x8xf32>)
///                  outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
/// %2 = linalg.div
///          ins(%x_trans_bc, %y :
///                  tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
///          outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
/// ```
///
/// Note that the linalg.generic has been 'specialized' to linalg.div.
///
/// When unfolding, it is preferable to transpose first and then broadcast.
/// However, if the transpose is done first, its permutation map must be
/// expressed in terms of the reduced (pre-broadcast) dimensions, since the
/// broadcast has not happened yet. Also, the broadcast dimensions of a
/// linalg.generic operand are implied by the other operands (those not
/// broadcast along that particular dimension). We work this out by
/// computing the convex-polyhedron shape of the linalg.generic iteration
/// space from the shapes of all the operands, both inputs and outputs.
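///
/// For the example above, the iteration space is 5x9x7x8x10, recovered
/// from the identity-mapped operands `%y` and `%z`; this is the shape the
/// broadcast of `%x` is created with.
///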
struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;
};

/// For the given `map`, determine which dimensions are transposed and
/// which dimensions are broadcast.
/// Returns:
///   `{transpose-permutation, broadcast-dimensions}` (each empty if not
///   needed).
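///
/// For example, `(d0, d1, d2, d3, d4) -> (d2, d3, d1)` yields the
/// permutation `[2, 0, 1]` and the broadcast dimensions `[0, 4]`,
/// matching the linalg.transpose and linalg.broadcast shown above.
///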
std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
computeTransposeBroadcast(AffineMap &map) {
  assert(map.isProjectedPermutation(false) && "not a projection");

  // As the map is a projection it likely operates on a smaller set of
  // dimensions as far as the transpose is concerned (rest are broadcast).
  int64_t minorSize = map.getNumResults();

  SmallVector<int64_t> minorResult;
  for (int64_t i = 0; i < minorSize; ++i) {
    auto expr = cast<AffineDimExpr>(map.getResults()[i]);
    minorResult.push_back(expr.getPosition());
  }

  // If the dims are not monotonically increasing then a transpose is present.
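  // E.g. results (d2, d3, d1) sort to (d1, d2, d3), signalling a transpose.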
  SmallVector<int64_t> sortedResMap(minorResult);
  llvm::sort(sortedResMap);
  bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
                                  sortedResMap.begin(), sortedResMap.end());

  // Walk the sorted map result to determine which dimensions are broadcasted.
  SmallVector<int64_t> broadcast;
  for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
    if (j < minorSize && sortedResMap[j] == i) {
      j++;
      continue;
    }
    broadcast.push_back(i);
  }

  SmallVector<int64_t> permutation;
  if (hasTranspose) {
    // Consider an operand `x : tensor<7x8x9>` of a genericOp that has
    // affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`.
    // `x`'s access is both transposed and broadcast. But when specifying
    // the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
    // specified as `affine_map<(d0, d1, d2) -> (d1, d2, d0)>` instead of
    // referring to d3, d4. Therefore, re-base the transpose dimensions so
    // that they start from d0.
    permutation.resize(minorSize);
    std::map<int64_t, int64_t> minorMap;
    for (int64_t i = 0; i < minorSize; ++i)
      minorMap.insert({sortedResMap[i], i});
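    // E.g. for sorted results (d1, d2, d3), minorMap is {1->0, 2->1, 3->2}.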

    // Re-map the dimensions.
    SmallVector<int64_t> remappedResult(minorSize);
    for (int64_t i = 0; i < minorSize; ++i)
      remappedResult[i] = minorMap[minorResult[i]];

    // Calculate the permutation for the transpose.
    for (unsigned i = 0; i < minorSize; ++i) {
      permutation[remappedResult[i]] = i;
    }
  }
  return {permutation, broadcast};
}

LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
    GenericOp op, PatternRewriter &rewriter) const {
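  // Bail out on generics without pure tensor semantics, on trivial
  // generics (a single input/output or a body that just yields), and on
  // generics with non-parallel iterators.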
  if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
      op.isSingleYieldOp() || !op.isAllParallelLoops())
    return failure();

  // If the map of an operand is not a `projected permutation` then
  // it cannot be decomposed into a mere transpose and broadcast.
  // The requirement that all maps be `projected permutation` may be
  // over-restrictive, but since we also need to determine the shape of
  // the iteration space, reject the op if any map violates the assumption.
  for (auto &opOperand : op->getOpOperands()) {
    auto map = op.getMatchingIndexingMap(&opOperand);
    if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
      return failure();
  }

  // Decomposing a linalg.generic involves creating `tensor.empty` ops,
  // which can have dynamic shapes, but then we would have to work out
  // which operand can supply that runtime value (tensor.dim).
  // Leaving it as a future TODO.
  if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
        auto opType = cast<RankedTensorType>(oper.get().getType());
        return ShapedType::isDynamicShape(opType.getShape());
      }))
    return failure();

  auto outputShape = op.getStaticLoopRanges();

  auto loc = op.getLoc();
  bool isChanged = false;
  SmallVector<Value> newInitValues = op.getDpsInputs();
  SmallVector<AffineMap> newMap = op.getIndexingMapsArray();

  // Walk over each input operand and unfold it if it is transposed,
  // broadcast, or a mix of the two, as given by the operand's affine map.
  for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
    auto &map = newMap[i];
    auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
    auto elType = inputRTType.getElementType();

    // Nothing to do if the map is already an identity.
    if (map.isIdentity())
      continue;

    auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);

    // Does it need a transpose?
    if (!permutation.empty()) {
      // linalg.transpose permutes the dimensions of the input using the
      // rule: dim(result, i) = dim(input, permutation[i]).
      SmallVector<int64_t> transposedShape(map.getNumResults());
      for (int64_t i = 0; i < map.getNumResults(); ++i)
        transposedShape[i] = inputRTType.getShape()[permutation[i]];
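      // E.g. a tensor<7x8x9xf32> input with permutation [2, 0, 1] gets the
      // transposed shape 9x7x8.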

      Value emptyTensor =
          rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);

      auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
                                                      emptyTensor, permutation);
      newInitValues[i] = transposeOp->getResult(0);
      isChanged = true;
    }

    // Does it require a broadcast?
    if (!broadcastedDims.empty()) {
      assert(broadcastedDims.size() && "should have non-empty broadcast dims");
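      // The broadcast destination always has the full iteration-space
      // shape (5x9x7x8x10 in the example above).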
      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
          loc, outputShape, inputRTType.getElementType());

      auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
          loc, newInitValues[i], emptyTensor, broadcastedDims);

      newInitValues[i] = broadcastOp->getResult(0);
      isChanged = true;
    }
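    // After unfolding, the operand is accessed through an identity map.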
    newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
  }

  if (!isChanged)
    return failure();

  SmallVector<Value> operands = op->getOperands();
  ValueRange operandsRef(operands);

  auto newOp = rewriter.create<linalg::GenericOp>(
      /*location=*/op.getLoc(),
      /*resultTensorTypes=*/op->getResultTypes(),
      /*inputs=*/newInitValues,
      /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
      /*indexingMaps=*/newMap,
      /*iteratorTypes=*/op.getIteratorTypesArray());
  newOp.getRegion().takeBody(op->getRegion(0));
  rewriter.replaceOp(op, newOp->getResults());
  return success();
}

} // namespace

void mlir::linalg::populateDecomposeProjectedPermutationPatterns(
    RewritePatternSet &patterns) {
  patterns.insert<DecomposeProjectedPermutation>(patterns.getContext());
}
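
// A minimal usage sketch: a pass would typically hand these patterns to
// the greedy rewrite driver (mlir/Transforms/GreedyPatternRewriteDriver.h).
// `MyPass` is hypothetical; the populate function above is the only entry
// point this file provides.
//
//   void MyPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     linalg::populateDecomposeProjectedPermutationPatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }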