1 | //===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
10 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
11 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
12 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
13 | #include "mlir/IR/PatternMatch.h" |
14 | |
15 | namespace mlir { |
16 | namespace tensor { |
17 | namespace { |
18 | |
19 | static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) { |
20 | return llvm::all_of( |
21 | Range&: ofrs, P: [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); }); |
22 | } |
23 | |
24 | /// Returns the number of shape sizes that is either dynamic or greater than 1. |
25 | static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) { |
26 | return llvm::count_if( |
27 | Range&: shape, P: [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; }); |
28 | } |
29 | |
30 | /// Returns success() if there is only 1 dimension size in non-packed domain |
31 | /// being greater than 1 and packing only happens on the dimension. |
32 | /// Note: this method should only be used by pack/unpack to reshape conversion. |
33 | /// It assumes that non-unit inner tile size must be used by the non-unit |
34 | /// dimension. |
35 | static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op, |
36 | ArrayRef<int64_t> srcShape, |
37 | ArrayRef<int64_t> innerPackTileSize) { |
38 | if (getNumGtOneDims(shape: srcShape) > 1) { |
39 | return rewriter.notifyMatchFailure( |
40 | arg&: op, msg: "expects non-packed domain to have at most one non-unit dims" ); |
41 | } |
42 | // Non-unit inner tile size must be used by the non-unit dimension. If not, it |
43 | // will faill on getting reassociation maps. |
44 | if (getNumGtOneDims(shape: innerPackTileSize) > 1) { |
45 | return rewriter.notifyMatchFailure( |
46 | arg&: op, msg: "expects at most one non-unit inner tiles" ); |
47 | } |
48 | return success(); |
49 | } |
50 | |
51 | /// Packing one-dimensional tensor can be expressed as an expand shape op. |
52 | struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> { |
53 | using OpRewritePattern<PackOp>::OpRewritePattern; |
54 | |
55 | Value insertExpand(RewriterBase &rewriter, Location loc, Value operand, |
56 | Type newOperandType, ArrayAttr reassociation) const { |
57 | if (operand.getType() == newOperandType) |
58 | return operand; |
59 | return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand, |
60 | reassociation); |
61 | } |
62 | |
63 | /// Returns success() if it is only packing on the innermost dimension. |
64 | LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter, |
65 | PackOp packOp) const { |
66 | auto outerDimsPerm = packOp.getOuterDimsPerm(); |
67 | if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) { |
68 | return rewriter.notifyMatchFailure( |
69 | packOp, |
70 | "expects outer_dims_perm is empty or an identity permutation" ); |
71 | } |
72 | |
73 | int64_t srcRank = packOp.getSourceRank(); |
74 | ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos(); |
75 | if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) { |
76 | return rewriter.notifyMatchFailure( |
77 | packOp, "expects packing at the innermost dimension" ); |
78 | } |
79 | return success(); |
80 | } |
81 | |
82 | LogicalResult matchAndRewrite(PackOp packOp, |
83 | PatternRewriter &rewriter) const override { |
84 | if (packOp.getPaddingValue()) |
85 | return rewriter.notifyMatchFailure(packOp, "expects no padding value" ); |
86 | |
87 | RankedTensorType sourceType = packOp.getSourceType(); |
88 | if (failed(isPackOnInnerMostDim(rewriter, packOp)) && |
89 | failed(isPackOn1D(rewriter, packOp, sourceType.getShape(), |
90 | packOp.getStaticTiles()))) { |
91 | return failure(); |
92 | } |
93 | |
94 | RankedTensorType destType = packOp.getDestType(); |
95 | auto reassociation = |
96 | getReassociationIndicesForReshape(sourceType, destType); |
97 | if (!reassociation) |
98 | return failure(); |
99 | Value expanded = insertExpand( |
100 | rewriter, loc: packOp.getLoc(), operand: packOp.getSource(), newOperandType: destType, |
101 | reassociation: getReassociationIndicesAttribute(rewriter, *reassociation)); |
102 | rewriter.replaceOp(packOp, expanded); |
103 | return success(); |
104 | } |
105 | }; |
106 | |
107 | struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> { |
108 | using OpRewritePattern<UnPackOp>::OpRewritePattern; |
109 | |
110 | Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand, |
111 | Type newOperandType, ArrayAttr reassociation) const { |
112 | if (operand.getType() == newOperandType) |
113 | return operand; |
114 | return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType, |
115 | operand, reassociation); |
116 | } |
117 | |
118 | /// Returns success() if it is unpacking on the innermost dimension. |
119 | LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter, |
120 | UnPackOp unpackOp) const { |
121 | auto outerDimsPerm = unpackOp.getOuterDimsPerm(); |
122 | if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) { |
123 | return rewriter.notifyMatchFailure( |
124 | unpackOp, |
125 | "expects outer_dims_perm is empty or an identity permutation" ); |
126 | } |
127 | |
128 | RankedTensorType sourceType = unpackOp.getSourceType(); |
129 | RankedTensorType destType = unpackOp.getDestType(); |
130 | if (!sourceType.hasStaticShape() || !destType.hasStaticShape()) |
131 | return rewriter.notifyMatchFailure(unpackOp, "expects static shapes" ); |
132 | |
133 | ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos(); |
134 | if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) { |
135 | return rewriter.notifyMatchFailure( |
136 | unpackOp, "expects unpacking on the innermost dimension" ); |
137 | } |
138 | |
139 | return success(); |
140 | } |
141 | |
142 | LogicalResult matchAndRewrite(UnPackOp unpackOp, |
143 | PatternRewriter &rewriter) const override { |
144 | RankedTensorType destType = unpackOp.getDestType(); |
145 | if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) && |
146 | failed(isPackOn1D(rewriter, unpackOp, destType.getShape(), |
147 | unpackOp.getStaticTiles()))) { |
148 | return failure(); |
149 | } |
150 | |
151 | RankedTensorType sourceType = unpackOp.getSourceType(); |
152 | auto reassociation = |
153 | getReassociationIndicesForReshape(sourceType, destType); |
154 | if (!reassociation) |
155 | return failure(); |
156 | Value collapsed = insertCollapse( |
157 | rewriter, loc: unpackOp.getLoc(), operand: unpackOp.getSource(), newOperandType: destType, |
158 | reassociation: getReassociationIndicesAttribute(rewriter, *reassociation)); |
159 | rewriter.replaceOp(unpackOp, collapsed); |
160 | return success(); |
161 | } |
162 | }; |
163 | |
164 | /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and |
165 | /// the pad op has zero low paddings, or if `pack` has no padding values. |
166 | struct FoldPadWithPackOp : public OpRewritePattern<PackOp> { |
167 | using OpRewritePattern<PackOp>::OpRewritePattern; |
168 | |
169 | LogicalResult matchAndRewrite(PackOp packOp, |
170 | PatternRewriter &rewriter) const override { |
171 | auto padOp = packOp.getSource().getDefiningOp<PadOp>(); |
172 | |
173 | if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad()) |
174 | return failure(); |
175 | |
176 | Value constantPaddingValue = padOp.getConstantPaddingValue(); |
177 | if (!constantPaddingValue) |
178 | return failure(); |
179 | |
180 | if (auto paddingValue = packOp.getPaddingValue()) |
181 | if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue)) |
182 | return failure(); |
183 | |
184 | rewriter.replaceOpWithNewOp<PackOp>( |
185 | packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(), |
186 | packOp.getMixedTiles(), constantPaddingValue, |
187 | packOp.getOuterDimsPerm()); |
188 | return success(); |
189 | } |
190 | }; |
191 | |
192 | /// Fold a `unpack` -> `extract_slice` into the `unpack` since it already |
193 | /// has extract_slice semantics. |
194 | struct : public OpRewritePattern<ExtractSliceOp> { |
195 | using OpRewritePattern<ExtractSliceOp>::OpRewritePattern; |
196 | |
197 | LogicalResult matchAndRewrite(ExtractSliceOp sliceOp, |
198 | PatternRewriter &rewriter) const override { |
199 | auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>(); |
200 | if (!unpackOp) |
201 | return failure(); |
202 | |
203 | if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) { |
204 | return rewriter.notifyMatchFailure( |
205 | sliceOp, "rank-reduced folding is not supported" ); |
206 | } |
207 | |
208 | // Check all offsets are zeros, and all strides are ones. |
209 | if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) || |
210 | !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) { |
211 | return rewriter.notifyMatchFailure( |
212 | sliceOp, "expects offsets to be 0s and strides to be 1s" ); |
213 | } |
214 | |
215 | // Create a new empty output tensor. |
216 | Type elementType = unpackOp.getDestType().getElementType(); |
217 | Value output = rewriter.create<EmptyOp>( |
218 | sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType); |
219 | rewriter.replaceOpWithNewOp<UnPackOp>( |
220 | sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(), |
221 | unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm()); |
222 | return success(); |
223 | } |
224 | }; |
225 | |
226 | // Applies 'permutation' on 'inVec' and stores the result in resVec. |
227 | // 'inVec' may be empty, in that case it's one-to-one mapping with permutation. |
228 | // `rank` sets the boundary for permutation i.e., the permutation dim can't be |
229 | // greater than the rank specified. If it's so then return false. |
230 | // For e.g., permutation {1, 0, 3, 2} with rank 2 is allowed since the values in |
231 | // permutation[:rank] doesn't exceed rank, whereas, permutation {1, 3, 0, 2} is |
232 | // not allowed since `3` exceeds the value of the rank in the given range. |
233 | static bool checkAndPermute(ArrayRef<int64_t> permutation, |
234 | ArrayRef<int64_t> inVec, |
235 | SmallVectorImpl<int64_t> &resVec, int64_t rank) { |
236 | |
237 | for (unsigned int i = 0; i < rank; ++i) { |
238 | int64_t remappedPosition = permutation[i]; |
239 | |
240 | if (!inVec.empty()) { |
241 | if (remappedPosition >= rank) { |
242 | return false; |
243 | } |
244 | remappedPosition = inVec[remappedPosition]; |
245 | } |
246 | |
247 | resVec.push_back(Elt: remappedPosition); |
248 | } |
249 | |
250 | return true; |
251 | } |
252 | |
253 | /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose |
254 | /// semantics. |
255 | struct FoldProducerPackWithConsumerLinalgTransposeOp |
256 | : public OpRewritePattern<linalg::TransposeOp> { |
257 | using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern; |
258 | |
259 | LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, |
260 | PatternRewriter &rewriter) const override { |
261 | auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>(); |
262 | |
263 | if (!packOp) |
264 | return failure(); |
265 | |
266 | auto innerDimsPos = packOp.getInnerDimsPos(); |
267 | auto mixedInnerTiles = packOp.getMixedTiles(); |
268 | auto outerDimsPerm = packOp.getOuterDimsPerm(); |
269 | auto transposePerm = transposeOp.getPermutation(); |
270 | SmallVector<int64_t> newOuterDimsPermVec; |
271 | SmallVector<int64_t> newInnerDimsPosVec; |
272 | SmallVector<OpFoldResult> newMixedInnerTilesVec; |
273 | int64_t srcRank = packOp.getSourceRank(); |
274 | |
275 | if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec, |
276 | srcRank)) |
277 | return rewriter.notifyMatchFailure( |
278 | transposeOp, |
279 | "Cannot fold in tensor.pack if a tile dimension was transposed " |
280 | "with a non-tile dimension in linalg.transpose." ); |
281 | |
282 | // Process transpose operation for tiled inner dimensions |
283 | for (unsigned int i = srcRank; i < transposePerm.size(); ++i) { |
284 | int64_t remappedPosition = transposePerm[i] - srcRank; |
285 | newMixedInnerTilesVec.push_back(Elt: mixedInnerTiles[remappedPosition]); |
286 | newInnerDimsPosVec.push_back(Elt: innerDimsPos[remappedPosition]); |
287 | } |
288 | |
289 | Value output = packOp.createDestinationTensor( |
290 | rewriter, transposeOp.getLoc(), packOp.getSource(), |
291 | newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec); |
292 | |
293 | rewriter.replaceOpWithNewOp<PackOp>( |
294 | transposeOp, packOp.getSource(), output, newInnerDimsPosVec, |
295 | newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec); |
296 | |
297 | return success(); |
298 | } |
299 | }; |
300 | |
301 | /// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose |
302 | /// semantics. |
303 | struct FoldConsumerPackWithProducerLinalgTransposeOp |
304 | : public OpRewritePattern<PackOp> { |
305 | using OpRewritePattern<PackOp>::OpRewritePattern; |
306 | |
307 | LogicalResult matchAndRewrite(PackOp packOp, |
308 | PatternRewriter &rewriter) const override { |
309 | auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>(); |
310 | |
311 | if (!transposeOp) |
312 | return failure(); |
313 | |
314 | auto transposePermutation = transposeOp.getPermutation(); |
315 | auto outerDimsPerm = packOp.getOuterDimsPerm(); |
316 | auto innerDimsPos = packOp.getInnerDimsPos(); |
317 | SmallVector<int64_t> newInnerDimsPosVec; |
318 | SmallVector<int64_t> newOuterDimsPermVec = |
319 | llvm::to_vector(transposePermutation); |
320 | |
321 | if (!outerDimsPerm.empty()) |
322 | applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm); |
323 | |
324 | // Can't use applyPermutationToVector for newInnerDimsPosVec since input and |
325 | // permutation rank won't necessarily be equal in all cases. |
326 | for (auto dim : innerDimsPos) |
327 | newInnerDimsPosVec.push_back(transposePermutation[dim]); |
328 | |
329 | Value output = packOp.createDestinationTensor( |
330 | rewriter, packOp.getLoc(), transposeOp.getOperand(0), |
331 | packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec); |
332 | |
333 | rewriter.replaceOpWithNewOp<PackOp>( |
334 | packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec, |
335 | packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec); |
336 | |
337 | return success(); |
338 | } |
339 | }; |
340 | |
341 | /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has |
342 | /// transpose semantics. |
343 | struct FoldProducerUnPackWithConsumerLinalgTransposeOp |
344 | : public OpRewritePattern<linalg::TransposeOp> { |
345 | using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern; |
346 | |
347 | LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, |
348 | PatternRewriter &rewriter) const override { |
349 | auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>(); |
350 | |
351 | if (!unPackOp) |
352 | return failure(); |
353 | |
354 | auto transposePermutation = transposeOp.getPermutation(); |
355 | auto outerDimsPerm = unPackOp.getOuterDimsPerm(); |
356 | auto innerDimsPos = unPackOp.getInnerDimsPos(); |
357 | SmallVector<int64_t> newInnerDimsPosVec; |
358 | SmallVector<int64_t> newOuterDimsPermVec = |
359 | llvm::to_vector(transposePermutation); |
360 | |
361 | if (!outerDimsPerm.empty()) |
362 | applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm); |
363 | |
364 | // Can't use applyPermutationToVector for newInnerDimsPosVec since input and |
365 | // permutation rank won't necessarily be equal in all cases. |
366 | for (auto dim : innerDimsPos) |
367 | newInnerDimsPosVec.push_back(transposePermutation[dim]); |
368 | |
369 | Value output = unPackOp.createDestinationTensor( |
370 | rewriter, transposeOp.getLoc(), unPackOp.getSource(), |
371 | unPackOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec); |
372 | |
373 | rewriter.replaceOpWithNewOp<UnPackOp>( |
374 | transposeOp, unPackOp.getSource(), output, newInnerDimsPosVec, |
375 | unPackOp.getMixedTiles(), newOuterDimsPermVec); |
376 | |
377 | return success(); |
378 | } |
379 | }; |
380 | |
381 | /// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has |
382 | /// transpose semantics. |
383 | struct FoldConsumerUnPackWithProducerLinalgTransposeOp |
384 | : public OpRewritePattern<UnPackOp> { |
385 | using OpRewritePattern<UnPackOp>::OpRewritePattern; |
386 | |
387 | LogicalResult matchAndRewrite(UnPackOp unPackOp, |
388 | PatternRewriter &rewriter) const override { |
389 | auto transposeOp = |
390 | unPackOp.getSource().getDefiningOp<linalg::TransposeOp>(); |
391 | |
392 | if (!transposeOp) |
393 | return failure(); |
394 | |
395 | auto transposePermutation = transposeOp.getPermutation(); |
396 | auto outerDimsPerm = unPackOp.getOuterDimsPerm(); |
397 | auto innerDimsPos = unPackOp.getInnerDimsPos(); |
398 | int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size(); |
399 | auto mixedInnerTilesVec = unPackOp.getMixedTiles(); |
400 | SmallVector<int64_t> newOuterDimsPermVec; |
401 | SmallVector<int64_t> newInnerDimsPosVec; |
402 | SmallVector<OpFoldResult> newMixedInnerTilesVec; |
403 | |
404 | if (!checkAndPermute(transposePermutation, outerDimsPerm, |
405 | newOuterDimsPermVec, destRank)) |
406 | return rewriter.notifyMatchFailure( |
407 | unPackOp, |
408 | "Cannot fold in tensor.unpack if a tile dimension was transposed " |
409 | "with a non-tile dimension in linalg.transpose." ); |
410 | |
411 | // Process transpose operation for tiled inner dimensions |
412 | for (unsigned int i = destRank; i < transposePermutation.size(); ++i) { |
413 | int64_t remappedPosition = transposePermutation[i] - destRank; |
414 | newMixedInnerTilesVec.push_back(Elt: mixedInnerTilesVec[remappedPosition]); |
415 | newInnerDimsPosVec.push_back(Elt: innerDimsPos[remappedPosition]); |
416 | } |
417 | |
418 | Value output = unPackOp.createDestinationTensor( |
419 | rewriter, unPackOp.getLoc(), transposeOp.getOperand(0), |
420 | newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec); |
421 | |
422 | rewriter.replaceOpWithNewOp<UnPackOp>( |
423 | unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec, |
424 | newMixedInnerTilesVec, newOuterDimsPermVec); |
425 | |
426 | return success(); |
427 | } |
428 | }; |
429 | } // namespace |
430 | |
431 | void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) { |
432 | patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp, |
433 | FoldProducerPackWithConsumerLinalgTransposeOp, |
434 | FoldConsumerPackWithProducerLinalgTransposeOp, |
435 | FoldConsumerUnPackWithProducerLinalgTransposeOp, |
436 | FoldProducerUnPackWithConsumerLinalgTransposeOp>( |
437 | arg: patterns.getContext()); |
438 | } |
439 | |
440 | void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) { |
441 | patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>( |
442 | arg: patterns.getContext()); |
443 | } |
444 | |
445 | } // namespace tensor |
446 | } // namespace mlir |
447 | |