//===- DecomposeGenericByUnfoldingPermutation.cpp -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include <map>
#include <optional>
#include <utility>

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// This pattern decomposes the input operand(s) of a linalg.generic that has
/// a `transpose`, a `broadcast`, or a mixture of the two folded into its
/// indexing maps, into explicit transpose and broadcast ops. Having them
/// folded into the linalg.generic is a good optimization, but sometimes we
/// may want to unwrap, i.e. `unfold`, them into explicit transpose and
/// broadcast ops. This rewrite pattern does that for each input operand. It
/// is useful, for instance, when trying to recognize named ops.
///
/// The transpose, broadcast, or mixture of both is expressed in the affine
/// map of the operand; technically it is a `projected permutation`.
///
/// Example:
///
/// ```mlir
///
/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
/// #identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
/// ...
/// %res = linalg.generic
///    { indexing_maps = [#projection, #identity, #identity],
///      iterator_types = ["parallel", "parallel", "parallel",
///                        "parallel", "parallel"] }
///    ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
///    outs(%z : tensor<5x9x7x8x10xf32>) {
///  ^bb0(%in: f32, %in_1: f32, %out: f32):
///    %div = arith.divf %in, %in_1 : f32
///    linalg.yield %div : f32
/// } -> tensor<5x9x7x8x10xf32>
/// ```
///
/// In the above IR, the map of operand `%x` is a projected permutation,
/// which can be unfolded as:
///
/// ```mlir
/// ...
/// %x_trans = linalg.transpose
///              ins(%x : tensor<7x8x9xf32>)
///              outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
/// ...
/// %x_trans_bc = linalg.broadcast
///                 ins(%x_trans : tensor<9x7x8xf32>)
///                 outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
/// %2 = linalg.div
///        ins(%x_trans_bc, %y :
///            tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
///        outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
/// ```
///
/// Note that the linalg.generic has been 'specialized' to linalg.div.
///
/// When unfolding, it is more efficient to transpose first and then
/// broadcast. However, if the transpose is done first, the permutation map
/// needs to be expressed in terms of the reduced (pre-broadcast) dimensions,
/// since the broadcast has not happened yet. Also, the sizes of the broadcast
/// dimensions in a linalg.generic come from the other operands (those not
/// broadcast along that particular dimension). We work this out by computing
/// the shape of the linalg.generic iteration space from the shapes of all
/// the operands, both inputs and outputs.
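///
/// For the example above, the iteration space is 5x9x7x8x10, taken from the
/// identity-mapped operands `%y` and `%z`; the broadcast dimensions [0, 4]
/// therefore pick up sizes 5 and 10, which `%x` does not carry.
///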
struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;
};

/// For the given `map`, determine which dimensions are transposed and which
/// dimensions are broadcast.
/// Returns:
///   `transpose-permutation`, `broadcast-dimensions` (each empty if not
///   needed).
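///
/// Worked example (the map from the pattern's doc comment): for
/// `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>` the result dims
/// [2, 3, 1] sort to [1, 2, 3]; re-based to start at d0 they become
/// [1, 2, 0], giving the transpose permutation [2, 0, 1], while the unused
/// input dims d0 and d4 give the broadcast dimensions [0, 4].
///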
std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
computeTransposeBroadcast(AffineMap &map) {
  assert(map.isProjectedPermutation(/*allowZeroInResults=*/false) &&
         "not a projection");

  // As the map is a projection, it likely operates on a smaller set of
  // dimensions as far as the transpose is concerned (the rest are broadcast).
  int64_t minorSize = map.getNumResults();

  SmallVector<int64_t> minorResult;
  for (int64_t i = 0; i < minorSize; ++i) {
    auto expr = cast<AffineDimExpr>(map.getResults()[i]);
    minorResult.push_back(expr.getPosition());
  }

  // If the dims are not monotonically increasing then a transpose is present.
  SmallVector<int64_t> sortedResMap(minorResult);
  llvm::sort(sortedResMap);
  bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
                                  sortedResMap.begin(), sortedResMap.end());

  // Walk the sorted map result to determine which dimensions are broadcast.
  SmallVector<int64_t> broadcast;
  for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
    if (j < minorSize && sortedResMap[j] == i) {
      j++;
      continue;
    }
    broadcast.push_back(i);
  }

  SmallVector<int64_t> permutation;
  if (hasTranspose) {
    // Consider an operand `x : tensor<7x8x9>` of a genericOp that has
    // affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`.
    // `x`'s access is both transposed and broadcast. But when specifying
    // the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
    // specified as `affine_map<(d0, d1, d2) -> (d1, d2, d0)>` instead of
    // referring to d3, d4. Therefore, re-base the transpose dimensions so
    // that they start from d0.
    permutation.resize(minorSize);
    std::map<int64_t, int64_t> minorMap;
    for (int64_t i = 0; i < minorSize; ++i)
      minorMap.insert({sortedResMap[i], i});

    // Re-map the dimensions.
    SmallVector<int64_t> remappedResult(minorSize);
    for (int64_t i = 0; i < minorSize; ++i)
      remappedResult[i] = minorMap[minorResult[i]];

    // Calculate the permutation for the transpose.
    for (int64_t i = 0; i < minorSize; ++i) {
      permutation[remappedResult[i]] = i;
    }
  }
  return {permutation, broadcast};
}

LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
    GenericOp op, PatternRewriter &rewriter) const {
  if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
      op.isSingleYieldOp() || !op.isAllParallelLoops())
    return failure();

  // If the map of an operand is not a `projected permutation` then it cannot
  // be decomposed into a mere transpose and broadcast. Requiring that all
  // maps be `projected permutations` may be over-restrictive, but since we
  // also need to determine the shape of the iteration space, reject the op
  // if any map violates that assumption.
  for (auto &opOperand : op->getOpOperands()) {
    auto map = op.getMatchingIndexingMap(&opOperand);
    if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
      return failure();
  }

  // Decomposing the linalg.generic involves creating `tensor.empty` ops,
  // which can have dynamic shapes, but then we would have to work out which
  // operand can supply that runtime value (via tensor.dim). Leaving that as
  // a future TODO.
  if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
        auto opType = cast<RankedTensorType>(oper.get().getType());
        return ShapedType::isDynamicShape(opType.getShape());
      }))
    return failure();

  auto outputShape = op.getStaticLoopRanges();

  auto loc = op.getLoc();
  bool isChanged = false;
  SmallVector<Value> newInitValues = op.getDpsInputs();
  SmallVector<AffineMap> newMap = op.getIndexingMapsArray();

  // Walk over each input operand and unfold it if its affine map encodes a
  // transpose, a broadcast, or a mix of the two.
  for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
    auto &map = newMap[i];
    auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
    auto elType = inputRTType.getElementType();

    // Nothing to do if the map is already an identity.
    if (map.isIdentity())
      continue;

    auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);

    // Does it need a transpose?
    if (!permutation.empty()) {
      // linalg.transpose permutes the dimensions of the input using the
      // rule: dim(result, i) == dim(input, permutation[i]).
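      // For the running example, permutation = [2, 0, 1] turns
      // tensor<7x8x9xf32> into tensor<9x7x8xf32>: dim 0 of the result is
      // dim 2 of the input, and so on.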
      SmallVector<int64_t> transposedShape(map.getNumResults());
      for (int64_t i = 0; i < map.getNumResults(); ++i)
        transposedShape[i] = inputRTType.getShape()[permutation[i]];

      Value emptyTensor =
          rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);

      auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
                                                      emptyTensor, permutation);
      newInitValues[i] = transposeOp->getResult(0);
      isChanged = true;
    }

    // Does it require a broadcast?
    if (!broadcastedDims.empty()) {
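      // For the running example, broadcast dimensions [0, 4] insert loop
      // dims d0 and d4, expanding tensor<9x7x8xf32> to the full
      // iteration-space shape tensor<5x9x7x8x10xf32>.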
      Value emptyTensor =
          rewriter.create<tensor::EmptyOp>(loc, outputShape, elType);

      auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
          loc, newInitValues[i], emptyTensor, broadcastedDims);

      newInitValues[i] = broadcastOp->getResult(0);
      isChanged = true;
    }
    newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
  }

  if (!isChanged)
    return failure();

  SmallVector<Value> operands = op->getOperands();
  ValueRange operandsRef(operands);

  auto newOp = rewriter.create<linalg::GenericOp>(
      /*location=*/op.getLoc(),
      /*resultTensorTypes=*/op->getResultTypes(),
      /*inputs=*/newInitValues,
      /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
      /*indexingMaps=*/newMap,
      /*iteratorTypes=*/op.getIteratorTypesArray());
  newOp.getRegion().takeBody(op->getRegion(0));
  rewriter.replaceOp(op, newOp->getResults());
  return success();
}

} // namespace

void mlir::linalg::populateDecomposeProjectedPermutationPatterns(
    RewritePatternSet &patterns) {
  patterns.insert<DecomposeProjectedPermutation>(patterns.getContext());
}
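
// A minimal usage sketch (hypothetical pass body, not part of this file;
// assumes the greedy rewrite driver from
// mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   void MyPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     linalg::populateDecomposeProjectedPermutationPatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }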