//===- PackAndUnpackPatterns.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/IR/Linalg.h"
10#include "mlir/Dialect/Tensor/IR/Tensor.h"
11#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
12#include "mlir/Dialect/Utils/IndexingUtils.h"
13#include "mlir/IR/PatternMatch.h"
14
15namespace mlir {
16namespace tensor {
17namespace {
18
19static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
20 return llvm::all_of(
21 Range&: ofrs, P: [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
22}
23
24/// Returns the number of shape sizes that is either dynamic or greater than 1.
25static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
26 return llvm::count_if(
27 Range&: shape, P: [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
28}
29
30/// Returns success() if there is only 1 dimension size in non-packed domain
31/// being greater than 1 and packing only happens on the dimension.
32/// Note: this method should only be used by pack/unpack to reshape conversion.
33/// It assumes that non-unit inner tile size must be used by the non-unit
34/// dimension.
35static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
36 ArrayRef<int64_t> srcShape,
37 ArrayRef<int64_t> innerPackTileSize) {
38 if (getNumGtOneDims(shape: srcShape) > 1) {
39 return rewriter.notifyMatchFailure(
40 arg&: op, msg: "expects non-packed domain to have at most one non-unit dims");
41 }
42 // Non-unit inner tile size must be used by the non-unit dimension. If not, it
43 // will faill on getting reassociation maps.
44 if (getNumGtOneDims(shape: innerPackTileSize) > 1) {
45 return rewriter.notifyMatchFailure(
46 arg&: op, msg: "expects at most one non-unit inner tiles");
47 }
48 return success();
49}
50
51/// Packing one-dimensional tensor can be expressed as an expand shape op.
52struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
53 using OpRewritePattern<PackOp>::OpRewritePattern;
54
55 Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
56 Type newOperandType, ArrayAttr reassociation) const {
57 if (operand.getType() == newOperandType)
58 return operand;
59 return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
60 reassociation);
61 }
62
63 /// Returns success() if it is only packing on the innermost dimension.
64 LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
65 PackOp packOp) const {
66 auto outerDimsPerm = packOp.getOuterDimsPerm();
67 if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
68 return rewriter.notifyMatchFailure(
69 packOp,
70 "expects outer_dims_perm is empty or an identity permutation");
71 }
72
73 int64_t srcRank = packOp.getSourceRank();
74 ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
75 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
76 return rewriter.notifyMatchFailure(
77 packOp, "expects packing at the innermost dimension");
78 }
79 return success();
80 }
81
82 LogicalResult matchAndRewrite(PackOp packOp,
83 PatternRewriter &rewriter) const override {
84 if (packOp.getPaddingValue())
85 return rewriter.notifyMatchFailure(packOp, "expects no padding value");
86
87 RankedTensorType sourceType = packOp.getSourceType();
88 if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
89 failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
90 packOp.getStaticTiles()))) {
91 return failure();
92 }
93
94 RankedTensorType destType = packOp.getDestType();
95 auto reassociation =
96 getReassociationIndicesForReshape(sourceType, destType);
97 if (!reassociation)
98 return failure();
99 Value expanded = insertExpand(
100 rewriter, loc: packOp.getLoc(), operand: packOp.getSource(), newOperandType: destType,
101 reassociation: getReassociationIndicesAttribute(rewriter, *reassociation));
102 rewriter.replaceOp(packOp, expanded);
103 return success();
104 }
105};
106
107struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
108 using OpRewritePattern<UnPackOp>::OpRewritePattern;
109
110 Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
111 Type newOperandType, ArrayAttr reassociation) const {
112 if (operand.getType() == newOperandType)
113 return operand;
114 return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
115 operand, reassociation);
116 }
117
118 /// Returns success() if it is unpacking on the innermost dimension.
119 LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
120 UnPackOp unpackOp) const {
121 auto outerDimsPerm = unpackOp.getOuterDimsPerm();
122 if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
123 return rewriter.notifyMatchFailure(
124 unpackOp,
125 "expects outer_dims_perm is empty or an identity permutation");
126 }
127
128 RankedTensorType sourceType = unpackOp.getSourceType();
129 RankedTensorType destType = unpackOp.getDestType();
130 if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
131 return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
132
133 ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
134 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
135 return rewriter.notifyMatchFailure(
136 unpackOp, "expects unpacking on the innermost dimension");
137 }
138
139 return success();
140 }
141
142 LogicalResult matchAndRewrite(UnPackOp unpackOp,
143 PatternRewriter &rewriter) const override {
144 RankedTensorType destType = unpackOp.getDestType();
145 if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
146 failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
147 unpackOp.getStaticTiles()))) {
148 return failure();
149 }
150
151 RankedTensorType sourceType = unpackOp.getSourceType();
152 auto reassociation =
153 getReassociationIndicesForReshape(sourceType, destType);
154 if (!reassociation)
155 return failure();
156 Value collapsed = insertCollapse(
157 rewriter, loc: unpackOp.getLoc(), operand: unpackOp.getSource(), newOperandType: destType,
158 reassociation: getReassociationIndicesAttribute(rewriter, *reassociation));
159 rewriter.replaceOp(unpackOp, collapsed);
160 return success();
161 }
162};
163
164/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
165/// the pad op has zero low paddings, or if `pack` has no padding values.
166struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
167 using OpRewritePattern<PackOp>::OpRewritePattern;
168
169 LogicalResult matchAndRewrite(PackOp packOp,
170 PatternRewriter &rewriter) const override {
171 auto padOp = packOp.getSource().getDefiningOp<PadOp>();
172
173 if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
174 return failure();
175
176 Value constantPaddingValue = padOp.getConstantPaddingValue();
177 if (!constantPaddingValue)
178 return failure();
179
180 if (auto paddingValue = packOp.getPaddingValue())
181 if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
182 return failure();
183
184 rewriter.replaceOpWithNewOp<PackOp>(
185 packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
186 packOp.getMixedTiles(), constantPaddingValue,
187 packOp.getOuterDimsPerm());
188 return success();
189 }
190};
191
192/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
193/// has extract_slice semantics.
194struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
195 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
196
197 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
198 PatternRewriter &rewriter) const override {
199 auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
200 if (!unpackOp)
201 return failure();
202
203 if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
204 return rewriter.notifyMatchFailure(
205 sliceOp, "rank-reduced folding is not supported");
206 }
207
208 // Check all offsets are zeros, and all strides are ones.
209 if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
210 !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
211 return rewriter.notifyMatchFailure(
212 sliceOp, "expects offsets to be 0s and strides to be 1s");
213 }
214
215 // Create a new empty output tensor.
216 Type elementType = unpackOp.getDestType().getElementType();
217 Value output = rewriter.create<EmptyOp>(
218 sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
219 rewriter.replaceOpWithNewOp<UnPackOp>(
220 sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
221 unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
222 return success();
223 }
224};
225
226// Applies 'permutation' on 'inVec' and stores the result in resVec.
227// 'inVec' may be empty, in that case it's one-to-one mapping with permutation.
228// `rank` sets the boundary for permutation i.e., the permutation dim can't be
229// greater than the rank specified. If it's so then return false.
230// For e.g., permutation {1, 0, 3, 2} with rank 2 is allowed since the values in
231// permutation[:rank] doesn't exceed rank, whereas, permutation {1, 3, 0, 2} is
232// not allowed since `3` exceeds the value of the rank in the given range.
233static bool checkAndPermute(ArrayRef<int64_t> permutation,
234 ArrayRef<int64_t> inVec,
235 SmallVectorImpl<int64_t> &resVec, int64_t rank) {
236
237 for (unsigned int i = 0; i < rank; ++i) {
238 int64_t remappedPosition = permutation[i];
239
240 if (!inVec.empty()) {
241 if (remappedPosition >= rank) {
242 return false;
243 }
244 remappedPosition = inVec[remappedPosition];
245 }
246
247 resVec.push_back(Elt: remappedPosition);
248 }
249
250 return true;
251}
252
253/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
254/// semantics.
255struct FoldProducerPackWithConsumerLinalgTransposeOp
256 : public OpRewritePattern<linalg::TransposeOp> {
257 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
258
259 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
260 PatternRewriter &rewriter) const override {
261 auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
262
263 if (!packOp)
264 return failure();
265
266 auto innerDimsPos = packOp.getInnerDimsPos();
267 auto mixedInnerTiles = packOp.getMixedTiles();
268 auto outerDimsPerm = packOp.getOuterDimsPerm();
269 auto transposePerm = transposeOp.getPermutation();
270 SmallVector<int64_t> newOuterDimsPermVec;
271 SmallVector<int64_t> newInnerDimsPosVec;
272 SmallVector<OpFoldResult> newMixedInnerTilesVec;
273 int64_t srcRank = packOp.getSourceRank();
274
275 if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
276 srcRank))
277 return rewriter.notifyMatchFailure(
278 transposeOp,
279 "Cannot fold in tensor.pack if a tile dimension was transposed "
280 "with a non-tile dimension in linalg.transpose.");
281
282 // Process transpose operation for tiled inner dimensions
283 for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
284 int64_t remappedPosition = transposePerm[i] - srcRank;
285 newMixedInnerTilesVec.push_back(Elt: mixedInnerTiles[remappedPosition]);
286 newInnerDimsPosVec.push_back(Elt: innerDimsPos[remappedPosition]);
287 }
288
289 Value output = packOp.createDestinationTensor(
290 rewriter, transposeOp.getLoc(), packOp.getSource(),
291 newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
292
293 rewriter.replaceOpWithNewOp<PackOp>(
294 transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
295 newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
296
297 return success();
298 }
299};
300
301/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
302/// semantics.
303struct FoldConsumerPackWithProducerLinalgTransposeOp
304 : public OpRewritePattern<PackOp> {
305 using OpRewritePattern<PackOp>::OpRewritePattern;
306
307 LogicalResult matchAndRewrite(PackOp packOp,
308 PatternRewriter &rewriter) const override {
309 auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
310
311 if (!transposeOp)
312 return failure();
313
314 auto transposePermutation = transposeOp.getPermutation();
315 auto outerDimsPerm = packOp.getOuterDimsPerm();
316 auto innerDimsPos = packOp.getInnerDimsPos();
317 SmallVector<int64_t> newInnerDimsPosVec;
318 SmallVector<int64_t> newOuterDimsPermVec =
319 llvm::to_vector(transposePermutation);
320
321 if (!outerDimsPerm.empty())
322 applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
323
324 // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
325 // permutation rank won't necessarily be equal in all cases.
326 for (auto dim : innerDimsPos)
327 newInnerDimsPosVec.push_back(transposePermutation[dim]);
328
329 Value output = packOp.createDestinationTensor(
330 rewriter, packOp.getLoc(), transposeOp.getOperand(0),
331 packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
332
333 rewriter.replaceOpWithNewOp<PackOp>(
334 packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
335 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
336
337 return success();
338 }
339};
340
341/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
342/// transpose semantics.
343struct FoldProducerUnPackWithConsumerLinalgTransposeOp
344 : public OpRewritePattern<linalg::TransposeOp> {
345 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
346
347 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
348 PatternRewriter &rewriter) const override {
349 auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
350
351 if (!unPackOp)
352 return failure();
353
354 auto transposePermutation = transposeOp.getPermutation();
355 auto outerDimsPerm = unPackOp.getOuterDimsPerm();
356 auto innerDimsPos = unPackOp.getInnerDimsPos();
357 SmallVector<int64_t> newInnerDimsPosVec;
358 SmallVector<int64_t> newOuterDimsPermVec =
359 llvm::to_vector(transposePermutation);
360
361 if (!outerDimsPerm.empty())
362 applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
363
364 // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
365 // permutation rank won't necessarily be equal in all cases.
366 for (auto dim : innerDimsPos)
367 newInnerDimsPosVec.push_back(transposePermutation[dim]);
368
369 Value output = unPackOp.createDestinationTensor(
370 rewriter, transposeOp.getLoc(), unPackOp.getSource(),
371 unPackOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
372
373 rewriter.replaceOpWithNewOp<UnPackOp>(
374 transposeOp, unPackOp.getSource(), output, newInnerDimsPosVec,
375 unPackOp.getMixedTiles(), newOuterDimsPermVec);
376
377 return success();
378 }
379};
380
381/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
382/// transpose semantics.
383struct FoldConsumerUnPackWithProducerLinalgTransposeOp
384 : public OpRewritePattern<UnPackOp> {
385 using OpRewritePattern<UnPackOp>::OpRewritePattern;
386
387 LogicalResult matchAndRewrite(UnPackOp unPackOp,
388 PatternRewriter &rewriter) const override {
389 auto transposeOp =
390 unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
391
392 if (!transposeOp)
393 return failure();
394
395 auto transposePermutation = transposeOp.getPermutation();
396 auto outerDimsPerm = unPackOp.getOuterDimsPerm();
397 auto innerDimsPos = unPackOp.getInnerDimsPos();
398 int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
399 auto mixedInnerTilesVec = unPackOp.getMixedTiles();
400 SmallVector<int64_t> newOuterDimsPermVec;
401 SmallVector<int64_t> newInnerDimsPosVec;
402 SmallVector<OpFoldResult> newMixedInnerTilesVec;
403
404 if (!checkAndPermute(transposePermutation, outerDimsPerm,
405 newOuterDimsPermVec, destRank))
406 return rewriter.notifyMatchFailure(
407 unPackOp,
408 "Cannot fold in tensor.unpack if a tile dimension was transposed "
409 "with a non-tile dimension in linalg.transpose.");
410
411 // Process transpose operation for tiled inner dimensions
412 for (unsigned int i = destRank; i < transposePermutation.size(); ++i) {
413 int64_t remappedPosition = transposePermutation[i] - destRank;
414 newMixedInnerTilesVec.push_back(Elt: mixedInnerTilesVec[remappedPosition]);
415 newInnerDimsPosVec.push_back(Elt: innerDimsPos[remappedPosition]);
416 }
417
418 Value output = unPackOp.createDestinationTensor(
419 rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
420 newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
421
422 rewriter.replaceOpWithNewOp<UnPackOp>(
423 unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
424 newMixedInnerTilesVec, newOuterDimsPermVec);
425
426 return success();
427 }
428};
429} // namespace
430
431void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
432 patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
433 FoldProducerPackWithConsumerLinalgTransposeOp,
434 FoldConsumerPackWithProducerLinalgTransposeOp,
435 FoldConsumerUnPackWithProducerLinalgTransposeOp,
436 FoldProducerUnPackWithConsumerLinalgTransposeOp>(
437 arg: patterns.getContext());
438}
439
440void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
441 patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
442 arg: patterns.getContext());
443}
444
445} // namespace tensor
446} // namespace mlir
447

source code of mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp