//===- DataLayoutPropagation.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Passes.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Linalg/IR/Linalg.h"
13#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14#include "mlir/Dialect/Linalg/Utils/Utils.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/Dialect/Tensor/Utils/Utils.h"
17#include "mlir/Dialect/Utils/IndexingUtils.h"
18#include "mlir/IR/Dominance.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20#include "llvm/ADT/TypeSwitch.h"
21#include "llvm/Support/Debug.h"
22#include <optional>
23
24namespace mlir {
25#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
26#include "mlir/Dialect/Linalg/Passes.h.inc"
27} // namespace mlir
28
29using namespace mlir;
30using namespace mlir::linalg;
31
32#define DEBUG_TYPE "linalg-data-layout-propagation"
33
34namespace {
35
36static bool hasGatherSemantics(linalg::GenericOp genericOp) {
37 for (Operation &op : genericOp.getBody()->getOperations())
38 if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
39 return true;
40 return false;
41}
42
// The struct contains the information needed to map an op's packing
// configuration (tile sizes, inner and outer dim positions) onto the
// iteration domain of Linalg ops.
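//
// For illustration (an assumed example, not tied to a specific test): if the
// identity-mapped result of a 2-D elementwise generic is packed with
// inner_dims_pos = [0, 1], inner_tiles = [8, 2], and no outer_dims_perm, the
// resulting PackInfo is:
//   tiledDimsPos            = [0, 1]
//   domainDimAndTileMapping = {0 -> 8, 1 -> 2}
//   tileToPointMapping      = {0 -> 2, 1 -> 3}
//   outerDimsOnDomainPerm   = []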
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); }
  // InnerDimsPos on iteration domain, which follows the order in pack ops.
  SmallVector<int64_t> tiledDimsPos;
  // The sizes of tiling data dimensions on iteration domain.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // The mapping from a dimension of iteration domain to the corresponding
  // inner tiling dimension on iteration domain.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims (on domain).
  SmallVector<int64_t> outerDimsOnDomainPerm;
};
57
58template <typename OpTy>
59static FailureOr<PackInfo>
60getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
61 OpTy packOrUnPackOp) {
62 static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
63 "applies to only pack or unpack operations");
64 LLVM_DEBUG(
65 { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
66
67 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
68 SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
69 SmallVector<utils::IteratorType> iterators =
70 genericOp.getIteratorTypesArray();
71
72 PackInfo packInfo;
73 int64_t origNumDims = indexingMap.getNumDims();
74 SmallVector<AffineExpr> exprs(indexingMap.getResults());
75 ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
76 for (auto [index, innerDimPos, tileSize] :
77 llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
78 innerDimsPos, packOrUnPackOp.getMixedTiles())) {
79 auto expr = exprs[innerDimPos];
80 if (!isa<AffineDimExpr>(expr))
81 return failure();
82 int64_t domainDimPos =
83 cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
84 if (!isParallelIterator(iterators[domainDimPos]))
85 return failure();
86 packInfo.tiledDimsPos.push_back(domainDimPos);
87 packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
88 packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
89 LLVM_DEBUG({
90 llvm::dbgs() << "map innerDimPos=" << innerDimPos
91 << " to iteration dimension (d" << domainDimPos << ", d"
92 << packInfo.tileToPointMapping[domainDimPos]
93 << "), which has size=("
94 << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
95 });
96 }
97
  // Bail out if a tiled dimension is present in a map but not as an affine dim
  // expression.
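  // For example (illustrative), a map result such as `d0 + d1` is a function
  // of d0 but is not an AffineDimExpr, so packing along d0 is rejected.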
100 auto areAllAffineDimExpr = [&](int dim) {
101 for (AffineMap map : indexingMaps) {
102 if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
103 return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
104 })) {
105 return false;
106 }
107 }
108 return true;
109 };
110 for (int64_t i : packInfo.tiledDimsPos)
111 if (!areAllAffineDimExpr(i))
112 return failure();
113
114 // Get the outer dims perm on the iteration domain. Start by identifying the
115 // set of domain dims affected by the outer permutation along with the
116 // permuted ordering for those dims. Then the full outer dims permutation can
117 // be constructed by replacing the affected dims with the permuted result in a
118 // numLoops-rank identity. e.g.
119 // outerDimsPerm = [1, 2, 0]
120 // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
121 //
122 // permutedOuterDims = [4, 3, 1]
123 // outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
124 //
125 // Non-affine dim expressions must not be permuted by the outer dims
126 // permutation.
127 SmallVector<int64_t> permutedOuterDims;
128 for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
130 if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
131 permutedOuterDims.push_back(dimExpr.getPosition());
132 continue;
133 }
134
135 // TODO: Allow propagation with transposes on non affine dim expressions,
136 // e.g. d0 + d1 which implies transposing both dims simultaneously while
137 // maintaining the relative position between them.
138 if (static_cast<int64_t>(index) != dim)
139 return failure();
140 }
141 if (!permutedOuterDims.empty()) {
142 int64_t outerDimIndex = 0;
143 llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
144 permutedOuterDims.end());
145 for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
146 packInfo.outerDimsOnDomainPerm.push_back(
147 permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
148 : i);
149 LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
151 for (auto dim : packInfo.outerDimsOnDomainPerm)
152 llvm::dbgs() << dim << " ";
153 llvm::dbgs() << "\n";
154 });
155 }
156
157 return packInfo;
158}
159
160static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
161 ArrayRef<AffineExpr> exprs) {
162 // Compute `outer_dims_perm`. See example:
163 // current exprs : (d0, d1, d2, d3) -> (d2, d3)
164 // perm : [0, 3, 1, 2]
165 // First map d2, d3 with their position in the array as:
166 // currentPositionTileLoops: dim | pos
167 // d2 | 0
168 // d3 | 1
169 // then scan `perm` in order and get the `outer_dims_perm`
170 // to be used, here it would be [1, 0].
171 assert(!perm.empty() && "expect perm not to be empty");
172 assert(!exprs.empty() && "expect exprs not to be empty");
173 if (exprs.size() == 1)
174 return {};
175 SmallVector<int64_t> outerDimsPerm;
176 DenseMap<int64_t, int64_t> currentPositionTileLoops;
177 for (auto [pos, expr] : llvm::enumerate(exprs)) {
178 // Here we rely on the assumption that the outer dims permutation
179 // when propagating currently requires that non-affine dim expressions
180 // are not permuted, thus allowing the identity assignment below.
181 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
182 currentPositionTileLoops[dimExpr.getPosition()] = pos;
183 else
184 currentPositionTileLoops[pos] = pos;
185 }
186 for (int64_t loopIdx : perm) {
187 if (currentPositionTileLoops.count(loopIdx))
188 outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
189 }
190 return outerDimsPerm;
191}
192
193/// Returns a tuple for packed operand and indexing_map with the assumptions:
194/// 1) The generic op is the producer of the pack op.
195/// 2) The generic op has only one result.
196/// If the operand is a scalar or packing dimensions are all irrelevant to the
197/// operand, the operand and the updated indexing map will be returned.
198/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
199///
200/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
201/// #map1 = affine_map<(d0, d1) -> (d0)>
202/// #map2 = affine_map<(d0, d1) -> (d1)>
203/// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
204/// iterator_types = ["parallel", "parallel"]}
205/// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
206/// outs(%init : tensor<?x?xf32>) {
207/// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
208/// %4 = arith.addf %arg3, %arg4 : f32
209/// linalg.yield %4 : f32
210/// } -> tensor<?x?xf32>
211/// %1 = tensor.pack %0
212/// inner_dims_pos = [0, 1]
213/// inner_tiles = [8, 2]
214/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
215///
/// Taking the first input operand as an example, the inner tile size of d0 is
/// 8. Thus, the operation below and the indexing map
/// `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
219///
220/// %pack = tensor.pack %arg0
221/// inner_dims_pos = [0]
222/// inner_tiles = [8]
223/// into %init : tensor<?xf32> -> tensor<?x8xf32>
224static std::tuple<Value, AffineMap>
225getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
226 GenericOp genericOp, OpOperand *opOperand) {
227 int64_t numOrigLoops = genericOp.getNumLoops();
228 int64_t numInnerLoops = packInfo.getNumTiledLoops();
229 int64_t numLoops = numOrigLoops + numInnerLoops;
230 AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
231 llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
232 SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
233
234 // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
235 if (genericOp.isScalar(opOperand) || exprs.empty())
236 return std::make_tuple(opOperand->get(),
237 AffineMap::get(numLoops, 0, exprs, b.getContext()));
238
239 // Step 1. Construct the information of packing data dimensions; append inner
240 // dimensions to the indexing maps for the operand.
241 for (auto [index, expr] : llvm::enumerate(exprs)) {
242 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
243 int64_t dimPos = dimExpr.getPosition();
244 domainDimToOperandDim[dimPos] = index;
245 continue;
246 }
247 }
248 SmallVector<int64_t> innerDimsPos;
249 SmallVector<OpFoldResult> innerTileSizes;
250 for (auto dimPos : packInfo.tiledDimsPos) {
251 if (!domainDimToOperandDim.count(dimPos))
252 continue;
253 int64_t index = domainDimToOperandDim[dimPos];
254 innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
255 innerDimsPos.push_back(index);
256 exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
257 }
258
259 // Step 2. Handle outer dim permutations.
260 SmallVector<int64_t> outerDimsPerm;
261 if (!packInfo.outerDimsOnDomainPerm.empty()) {
262 outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
263
264 // Step 2.1: Fold transpose into the linalg.generic.
265 SmallVector<int64_t> inversedOuterPerm =
266 invertPermutationVector(packInfo.outerDimsOnDomainPerm);
267 for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
268 if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
269 int64_t dimPos = dimExpr.getPosition();
270 exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
271 continue;
272 }
273 assert(isa<AffineConstantExpr>(exprs[i]) &&
274 "Attempted to permute non-constant and non-affine dim expression");
275 }
276 // Step 2.2: Undo the transposition on `exprs` and propagate the
277 // transposition on the pack using outerDimsPerm.
278 if (!outerDimsPerm.empty()) {
279 SmallVector<AffineExpr> auxVec = exprs;
280 for (const auto &en : enumerate(outerDimsPerm))
281 auxVec[en.index()] = exprs[en.value()];
282 exprs = auxVec;
283 }
284 }
285 auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
286
  // The operand does not have dimensions that relate to the pack op.
288 if (innerDimsPos.empty() && outerDimsPerm.empty())
289 return std::make_tuple(opOperand->get(), indexingMap);
290
291 auto empty = tensor::PackOp::createDestinationTensor(
292 b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
293 auto packedOperand = b.create<tensor::PackOp>(
294 loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
295 /*padding=*/std::nullopt, outerDimsPerm);
296 return std::make_tuple(packedOperand, indexingMap);
297}
298
/// Packs the inputs of `genericOp` according to `packInfo`, appends the inner
/// (point) loops as parallel iterators, and returns a new generic op that
/// works on the packed domain with `dest` as its init operand.
300static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
301 Value dest, AffineMap packedOutIndexingMap,
302 const PackInfo &packInfo) {
303 Location loc = genericOp.getLoc();
304 SmallVector<Value> inputOperands;
305 SmallVector<AffineMap> indexingMaps;
306 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
307 auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
308 rewriter, loc, packInfo, genericOp, inputOperand);
309 inputOperands.push_back(packedOperand);
310 indexingMaps.push_back(packedIndexingMap);
311 }
312
313 int64_t numInnerLoops = packInfo.getNumTiledLoops();
314 SmallVector<utils::IteratorType> iterTypes =
315 genericOp.getIteratorTypesArray();
316 iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
317
318 indexingMaps.push_back(packedOutIndexingMap);
319
320 auto newGenericOp = rewriter.create<linalg::GenericOp>(
321 loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
322 /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
323 rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
324 newGenericOp.getRegion().begin());
325 return newGenericOp;
326}
327
/// Bubbles up a tensor.pack op through a producer generic op. This swaps
/// pack(generic) to generic(pack). The new generic op works on the packed
/// domain; pack ops are created for input and output operands. E.g.,
331///
332/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
333/// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
334/// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
335/// %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
336/// %3 = linalg.generic {indexing_maps = [#map0, #map0],
337/// iterator_types = ["parallel", "parallel"]}
338/// ins(%arg0 : tensor<?x?xf32>)
339/// outs(%2 : tensor<?x?xf32>) {
340/// ^bb0(%arg3: f32, %arg4: f32):
341/// %4 = arith.addf %arg3, %arg3 : f32
342/// linalg.yield %4 : f32
343/// } -> tensor<?x?xf32>
344/// %4 = tensor.pack %3
345/// inner_dims_pos = [0, 1]
346/// inner_tiles = [8, 2]
347/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
348///
349/// will be converted to
350///
351/// #map = affine_map<()[s0] -> (s0 ceildiv 8)>
352/// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
353/// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
354/// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
355/// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
356/// %0 = affine.apply #map()[%dim]
357/// %1 = affine.apply #map1()[%dim_0]
358/// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
359/// %pack = tensor.pack %arg0
360/// inner_dims_pos = [0, 1]
361/// inner_tiles = [8, 2]
362/// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
363/// %3 = linalg.generic {indexing_maps = [#map2, #map2],
364/// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
365/// ins(%pack : tensor<?x?x8x2xf32>)
366/// outs(%arg1 : tensor<?x?x8x2xf32>) {
367/// ^bb0(%in: f32, %out: f32):
368/// %4 = arith.addf %in, %in : f32
369/// linalg.yield %4 : f32
370/// } -> tensor<?x?x8x2xf32>
371static FailureOr<GenericOp>
372bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
373 const ControlPropagationFn &controlFn) {
374 auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
375 if (!genericOp)
376 return failure();
377
378 // User controlled propagation function.
379 if (!controlFn(genericOp))
380 return failure();
381
382 // TODO: Enable propagation in the presence of linalg.index and
383 // tensor.extract, likely as a separate pattern as the pack information and
384 // propagation decision needs to be inferred from the region of the generic.
385 if (hasGatherSemantics(genericOp))
386 return failure();
387
388 // TODO: Relax the restriction. We are able to bubble up the pack op through
389 // multi-result generic op. It just needs more work.
390 if (genericOp.getNumResults() != 1)
391 return failure();
392
393 // Bail-out if the result of the generic has multiple uses, as bubbling up
394 // creates recomputation if the generic has multiple users.
395 // TODO: Enable the case where every use is an identical pack op as no
396 // recomputation is needed in that case.
397 if (!genericOp->getResult(0).hasOneUse())
398 return failure();
399
  // We want to move the pack, not the generic.
401 OpBuilder::InsertionGuard guard(rewriter);
402 rewriter.setInsertionPoint(genericOp);
403
404 // We need to handle two cases:
405 // 1) The tensor.pack destination is a tensor.empty. If this is the case, we
406 // create a new tensor.empty to avoid breaking dominance, as we are moving the
407 // tensor.pack above the linalg.generic.
408 // 2) The destination is not a tensor.empty. In this case we can replace only
409 // if the destination of the tensor.pack dominates the linalg.generic.
410 Value packOpDest = packOp.getDest();
411 if (!packOpDest.hasOneUse())
412 return failure();
413 if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
414 packOpDest = rewriter.create<tensor::EmptyOp>(
415 genericOp->getLoc(), emptyOp.getMixedSizes(),
416 emptyOp.getType().getElementType());
417 } else {
418 DominanceInfo dom(genericOp);
419 if (!dom.properlyDominates(packOpDest, genericOp))
420 return failure();
421 }
422
  // TODO: Add an option for allowing padding values. Unconditionally
  // propagating the pack op through all ops could introduce undefined
  // behavior. E.g., if the padding value is zero and a generic op contains
  // division ops, some values in the padding area could become NaN (0/0).
427 if (packOp.getPaddingValue())
428 return failure();
429
430 OpOperand *opOperand = genericOp.getDpsInitOperand(0);
431 auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
432 if (failed(packInfo))
433 return failure();
434
435 // Rebuild the indexing map for the corresponding init operand.
436 auto [packedOutOperand, packedOutIndexingMap] =
437 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
438 genericOp, opOperand);
439
  // If the dps init operand of the generic is a tensor.empty, forward the pack
  // op destination.
442 Value dest = packedOutOperand;
443 if (auto initTensor = genericOp.getDpsInitOperand(0)
444 ->get()
445 .getDefiningOp<tensor::EmptyOp>()) {
446 dest = packOpDest;
447 }
448 return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
449 *packInfo);
450}
451
452/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
453struct BubbleUpPackOpThroughGenericOpPattern
454 : public OpRewritePattern<tensor::PackOp> {
455public:
456 BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
457 ControlPropagationFn fun)
458 : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
459
460 LogicalResult matchAndRewrite(tensor::PackOp packOp,
461 PatternRewriter &rewriter) const override {
462 auto genericOp =
463 bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
464 if (failed(genericOp))
465 return failure();
466 rewriter.replaceOp(packOp, genericOp->getResults());
467 return success();
468 }
469
470private:
471 ControlPropagationFn controlFn;
472};
473
/// Propagate a tensor.pack operation up through a tensor.pad. The idea is to
/// append a zero padding entry to `high` and `low` for each point (inner
/// tile) loop.
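///
/// An illustrative example (shapes and names chosen for exposition, assuming
/// a constant padding value and untiled padded dims):
///
///   %padded = tensor.pad %src low[0, 0] high[15, 0] { ... }
///       : tensor<17x64xf32> to tensor<32x64xf32>
///   %pack = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [16]
///       into %dest : tensor<32x64xf32> -> tensor<32x4x16xf32>
///
/// becomes
///
///   %pack = tensor.pack %src inner_dims_pos = [1] inner_tiles = [16]
///       into %empty : tensor<17x64xf32> -> tensor<17x4x16xf32>
///   %padded = tensor.pad %pack low[0, 0, 0] high[15, 0, 0] { ... }
///       : tensor<17x4x16xf32> to tensor<32x4x16xf32>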
477class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
478public:
479 BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
480 : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
481
482 LogicalResult matchAndRewrite(tensor::PackOp packOp,
483 PatternRewriter &rewriter) const override {
484 auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
485 if (!padOp)
486 return failure();
487
488 // User controlled propagation function.
489 if (!controlFn(padOp))
490 return failure();
491
492 if (!padOp.getResult().hasOneUse())
493 return failure();
494
495 // TODO: Enable padding when the padding values are the same.
496 if (packOp.getPaddingValue())
497 return failure();
498
499 // Fail for non-constant padding values. The body of the pad could
500 // depend on the padding indices and/or properties of the padded
501 // tensor so for now we fail.
502 // TODO: Support non-constant padding values.
503 Value paddingVal = padOp.getConstantPaddingValue();
504 if (!paddingVal)
505 return failure();
506
507 if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
508 return failure();
509
510 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
511 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
512
    // Bail out if one of the padded dimensions is a tiled one.
514 llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
515 llvm::SmallBitVector innerDims(paddedDims.size());
516 for (int64_t dim : innerDimsPos)
517 innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
519 return failure();
520
521 Location loc = padOp->getLoc();
522 OpBuilder::InsertionGuard guard(rewriter);
523 rewriter.setInsertionPoint(padOp);
524
525 auto empty = tensor::PackOp::createDestinationTensor(
526 rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
527 outerDimsPerm);
528 Value packedSource = rewriter.create<tensor::PackOp>(
529 loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
530 /*padding=*/std::nullopt, outerDimsPerm);
531
    // If we have `outer_dims_perm` we need to adjust the padded dimensions.
533 SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
534 SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
535 if (!outerDimsPerm.empty()) {
536 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
537 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
538 }
539 // The tiled dimensions were verified to be unpadded above, so here we
540 // just append 0 for the inner tile dimensions.
541 size_t pointLoopsSize = innerDimsPos.size();
542 lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
543 highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
544
545 auto newPadOp = rewriter.create<tensor::PadOp>(
546 loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
547 padOp.getNofold());
548 rewriter.replaceOp(packOp, newPadOp.getResult());
549 return success();
550 }
551
552private:
553 ControlPropagationFn controlFn;
554};
555
556/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
557///
/// For example, given dimsPos [0, 1], reassocIndices [[0, 1], [2, 3]], and
/// targetShape [16, 16, 32, 1], it returns [1, 2]: for pos 0, the inner-most
/// non-unit projected dim in [0, 1] is 1; for pos 1, the inner-most non-unit
/// projected dim in [2, 3] is 2.
562///
563/// If all candidates in a reassociation are unit dims, it chooses the
564/// inner-most dim pos.
565static SmallVector<int64_t>
566projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
567 ArrayRef<ReassociationIndices> reassocIndices,
568 ArrayRef<int64_t> targetShape) {
569 SmallVector<int64_t> projectedDimsPos;
570 for (auto pos : dimsPos) {
571 // In the case all dims are unit, this will return the inner-most one.
572 int64_t projectedPos = reassocIndices[pos].back();
573 for (auto i : llvm::reverse(reassocIndices[pos])) {
574 int64_t dim = targetShape[i];
575 if (dim > 1 || ShapedType::isDynamic(dim)) {
576 projectedPos = i;
577 break;
578 }
579 }
580 projectedDimsPos.push_back(projectedPos);
581 }
582 return projectedDimsPos;
583}
584
585/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
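/// E.g. (illustrative): shape = [?, 64], dimsPos = [1], tileSizes = [16]
/// returns true, while a dynamic dim or a non-zero remainder (e.g. 60 % 16)
/// returns false.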
586static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
587 ArrayRef<int64_t> shape,
588 ArrayRef<int64_t> tileSizes) {
589 for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
590 int64_t dim = shape[pos];
591 if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
592 return false;
593 }
594 return true;
595}
596
/// Permute the reassociation indices and reindex them in sequence order.
598/// Returns the next dim pos in the sequence.
599///
600/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
601/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
602/// [[0], [1, 2]].
603static int64_t applyPermutationAndReindexReassoc(
604 SmallVector<ReassociationIndices> &reassocIndices,
605 ArrayRef<int64_t> permutation) {
606 applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
607 int64_t nextPos = 0;
608 for (ReassociationIndices &indices : reassocIndices) {
609 for (auto &index : indices) {
610 index = nextPos;
611 nextPos += 1;
612 }
613 }
614 return nextPos;
615}
616
617/// Bubble up pack op through collapse shape op when the packed dims can be
618/// projected to the dims before collapsing. This is possible when the inner
619/// tile sizes can divide the projected dims.
620///
621/// For example:
622///
623/// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
624/// : tensor<?x16x4xf32> into tensor<?x4xf32>
625/// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
626/// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
627/// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
628///
629/// can be transformed into:
630///
/// %pack = tensor.pack %in outer_dims_perm = [0, 1, 2]
///     inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
///   : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
/// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
///   : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
636static LogicalResult
637bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
638 tensor::PackOp packOp,
639 PatternRewriter &rewriter) {
640 SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
641 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
642 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
643
644 ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
645 SmallVector<ReassociationIndices> reassocIndices =
646 collapseOp.getReassociationIndices();
  // Project inner tile pos to the dim pos before collapsing. For example, if
  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
  // to packing on dim y.
650 //
651 // Project to inner-most non-unit dims to increase the chance that they can be
652 // divided by the inner tile sizes. This is correct because for [..., x, 1],
653 // packing on dim 1 is equivalent to packing on dim x.
654 SmallVector<int64_t> projectedInnerDimsPos =
655 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
656
657 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
658 innerTileSizes)) {
659 return failure();
660 }
661 // Expand the outer dims permutation with the associated source dims for the
662 // new permutation after bubbling. This is because moving a collapsed dim is
663 // equivalent to moving the associated source dims together.
664 SmallVector<int64_t> newOuterDimsPerm;
665 for (auto outerPos : outerDimsPerm) {
666 newOuterDimsPerm.insert(newOuterDimsPerm.end(),
667 reassocIndices[outerPos].begin(),
668 reassocIndices[outerPos].end());
669 }
670
671 auto emptyOp = tensor::PackOp::createDestinationTensor(
672 rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
673 projectedInnerDimsPos, newOuterDimsPerm);
674 auto newPackOp = rewriter.create<tensor::PackOp>(
675 packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
676 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
677
678 SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
679 // First apply the permutation on the reassociations of the outer dims.
680 // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
681 // -> [[0], [1, 2]]
682 int64_t nextPos =
683 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
684 // Then add direct mapping for the inner tile dims.
685 for (size_t i = 0; i < innerDimsPos.size(); ++i) {
686 newReassocIndices.push_back({nextPos});
687 nextPos += 1;
688 }
689
690 auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
691 collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
692 rewriter.replaceOp(packOp, newCollapseOp);
693
694 return success();
695}
696
697class BubbleUpPackOpThroughReshapeOp final
698 : public OpRewritePattern<tensor::PackOp> {
699public:
700 BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
701 : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
702
703 LogicalResult matchAndRewrite(tensor::PackOp packOp,
704 PatternRewriter &rewriter) const override {
    Operation *srcOp = packOp.getSource().getDefiningOp();
    // Currently only supported when the pack op is the only user.
    if (!srcOp || srcOp->getNumResults() != 1 ||
        !srcOp->getResult(0).hasOneUse()) {
      return failure();
    }
711 // Currently only support static inner tile sizes.
712 if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
713 return ShapedType::isDynamic(size);
714 })) {
715 return failure();
716 }
717
718 // User controlled propagation function.
719 if (!controlFn(srcOp))
720 return failure();
721
722 return TypeSwitch<Operation *, LogicalResult>(srcOp)
723 .Case([&](tensor::CollapseShapeOp op) {
724 return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
725 })
726 .Default([](Operation *) { return failure(); });
727 }
728
729private:
730 ControlPropagationFn controlFn;
731};
732
733/// Push down unpack op through expand shape op when the packed dims can be
734/// projected to the dims after expanding. This is possible when the inner tile
735/// sizes can divide the projected dims.
736///
737/// For example:
738///
739/// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
740/// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
741/// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
742/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
743/// : tensor<?x256xf32> into tensor<?x256x256xf32>
744///
745/// can be transformed into:
746///
/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
748/// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
749/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
750/// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
751/// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
752static LogicalResult
753pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
754 tensor::ExpandShapeOp expandOp,
755 PatternRewriter &rewriter) {
756 SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
757 ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
758 ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
759
760 ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
761 SmallVector<ReassociationIndices> reassocIndices =
762 expandOp.getReassociationIndices();
  // Project inner tile pos to the dim pos after expanding. For example, if
  // dim z is expanded into [x, y], unpacking on dim z can be projected to
  // unpacking on dim y.
766 //
767 // Project to inner-most non-unit dims to increase the chance that they can be
768 // divided by the inner tile sizes. This is correct because for [..., x, 1],
769 // unpacking on dim 1 is equivalent to unpacking on dim x.
770 SmallVector<int64_t> projectedInnerDimsPos =
771 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
772
773 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
774 innerTileSizes)) {
775 return failure();
776 }
777 // Expand the outer dims permutation with the associated expanded dims for the
778 // new permutation after pushing. This is because moving a source dim is
779 // equivalent to moving the associated expanded dims together.
780 SmallVector<int64_t> newOuterDimsPerm;
781 for (auto outerPos : outerDimsPerm) {
782 newOuterDimsPerm.insert(newOuterDimsPerm.end(),
783 reassocIndices[outerPos].begin(),
784 reassocIndices[outerPos].end());
785 }
786
787 SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
788 // First apply the permutation on the reassociations of the outer dims.
789 // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
790 // -> [[0], [1, 2]]
791 int64_t nextPos =
792 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
793 // Then add direct mapping for the inner tile dims.
794 for (size_t i = 0; i < innerDimsPos.size(); ++i) {
795 newReassocIndices.push_back({nextPos});
796 nextPos += 1;
797 }
798
799 RankedTensorType newExpandType =
800 tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
801 projectedInnerDimsPos, newOuterDimsPerm);
802 auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
803 expandOp.getLoc(), newExpandType, unPackOp.getSource(),
804 newReassocIndices);
805
806 auto emptyOp = tensor::UnPackOp::createDestinationTensor(
807 rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
808 projectedInnerDimsPos, newOuterDimsPerm);
809 auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
810 unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
811 projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
812 rewriter.replaceOp(expandOp, newUnPackOp);
813
814 return success();
815}
816
817class PushDownUnPackOpThroughReshapeOp final
818 : public OpRewritePattern<tensor::UnPackOp> {
819public:
820 PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
821 ControlPropagationFn fun)
822 : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
823 }
824
825 LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
826 PatternRewriter &rewriter) const override {
827 Value result = unPackOp.getResult();
    // Currently only support unpack ops with a single user.
829 if (!result.hasOneUse()) {
830 return failure();
831 }
832 // Currently only support static inner tile sizes.
833 if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
834 return ShapedType::isDynamic(size);
835 })) {
836 return failure();
837 }
838
839 Operation *consumerOp = *result.user_begin();
840 // User controlled propagation function.
841 if (!controlFn(consumerOp))
842 return failure();
843
844 return TypeSwitch<Operation *, LogicalResult>(consumerOp)
845 .Case([&](tensor::ExpandShapeOp op) {
846 return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
847 })
848 .Default([](Operation *) { return failure(); });
849 }
850
851private:
852 ControlPropagationFn controlFn;
853};
854
855// TODO: Relax this restriction. We should unpack a generic op also
856// in the presence of multiple unpack ops as producers.
857/// Return the unpacked operand, if present, for the current generic op.
858static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
859 OpOperand *unPackedOperand = nullptr;
860 for (OpOperand &operand : genericOp->getOpOperands()) {
861 auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
862 if (!unPackOp)
863 continue;
864 if (unPackedOperand)
865 return failure();
866 unPackedOperand = &operand;
867 }
868 if (!unPackedOperand)
869 return failure();
870 return unPackedOperand;
871}
872
873/// Push down a tensor.unpack op through a generic op.
/// The new generic op works on the packed domain; pack ops are created for
/// input and output operands. A tensor.unpack op is inserted right after the
/// packed generic. E.g.,
877///
878/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
879///
880/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
881///
882/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
883/// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
884/// inner_dims_pos = [3] inner_tiles = [32] into %0
885/// %2 = linalg.generic {indexing_maps = [#map],
886/// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
887/// outs(%1 : tensor<12x56x56x64xf32>) {
888/// ^bb0(%out : f32):
889/// linalg.yield %out : f32
890/// } -> tensor<12x56x56x64xf32>
891///
892/// will be converted to
893///
894/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
895///
896/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
897/// %1 = linalg.generic {indexing_maps = [#map],
898/// iterator_types = ["parallel", "parallel", "parallel",
899/// "parallel", "parallel"]}
900/// outs(%arg0 : tensor<12x2x56x56x32xf32>) {
901/// ^bb0(%out : f32):
902/// linalg.yield %out : f32
903/// } -> tensor<12x2x56x56x32xf32>
904/// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
905/// inner_dims_pos = [3] inner_tiles = [32] into %0
906///
907static FailureOr<std::tuple<GenericOp, Value>>
908pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
909 if (genericOp.getNumResults() != 1)
910 return failure();
911
912 if (hasGatherSemantics(genericOp))
913 return failure();
914
915 // Collect the unPacked operand, if present.
916 auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
917 if (failed(maybeUnPackedOperand))
918 return failure();
919 OpOperand *unPackedOperand = *(maybeUnPackedOperand);
920
921 // Extract packing information.
922 tensor::UnPackOp producerUnPackOp =
923 unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
924 assert(producerUnPackOp && "expect a valid UnPackOp");
925 auto packInfo =
926 getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
927 if (failed(packInfo))
928 return failure();
929
930 // Rebuild the indexing map for the corresponding init operand.
931 auto [packedOutOperand, packedOutIndexingMap] =
932 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
933 genericOp, genericOp.getDpsInitOperand(0));
934 auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();
935
936 // If the dps init operand of the generic is a tensor.empty, do not pack it
937 // and forward the new tensor.empty as a destination.
938 Value dest = packedOutOperand;
939 if (auto initTensor = genericOp.getDpsInitOperand(0)
940 ->get()
941 .getDefiningOp<tensor::EmptyOp>()) {
942 if (destPack)
943 dest = destPack.getDest();
944 }
945
946 // Pack the genericOp.
947 GenericOp newGenericOp =
948 packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
949 Value newResult =
950 newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
951
952 // If the output is unaffected, no need to unpack.
953 if (!destPack)
954 return std::make_tuple(newGenericOp, newResult);
955
956 auto mixedTiles = destPack.getMixedTiles();
957 auto innerDimsPos = destPack.getInnerDimsPos();
958 auto outerDimsPerm = destPack.getOuterDimsPerm();
959
960 // If the output type for the generic differs from the source
961 // unpack op, we need to create a new destination tensor. In the
962 // dynamic case we always need a new destination.
963 auto loc = genericOp.getLoc();
964 Value unPackDest = producerUnPackOp.getDest();
965 auto genericOutType =
966 cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
967 if (producerUnPackOp.getDestType() != genericOutType ||
968 !genericOutType.hasStaticShape()) {
969 unPackDest = tensor::UnPackOp::createDestinationTensor(
970 rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
971 }
972
973 // Insert an unPackOp right after the packed generic.
974 Value unPackOpRes =
975 rewriter
976 .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
977 mixedTiles, outerDimsPerm)
978 .getResult();
979
980 return std::make_tuple(newGenericOp, unPackOpRes);
981}
982
983// Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
984struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
985public:
986 PushDownUnPackOpThroughGenericOp(MLIRContext *context,
987 ControlPropagationFn fun)
988 : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
989
990 LogicalResult matchAndRewrite(GenericOp genericOp,
991 PatternRewriter &rewriter) const override {
992 if (!controlFn(genericOp))
993 return failure();
994
995 auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
996 if (failed(genericAndRepl))
997 return failure();
998 rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
999 return success();
1000 }
1001
1002private:
1003 ControlPropagationFn controlFn;
1004};
1005
/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
/// append a zero padding entry to `high` and `low` for each point (inner
/// tile) loop.
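///
/// An illustrative example (shapes chosen for exposition, assuming a constant
/// padding value and untiled padded dims):
///
///   %unpack = tensor.unpack %src inner_dims_pos = [1] inner_tiles = [16]
///       into %dest : tensor<17x4x16xf32> -> tensor<17x64xf32>
///   %padded = tensor.pad %unpack low[0, 0] high[15, 0] { ... }
///       : tensor<17x64xf32> to tensor<32x64xf32>
///
/// becomes
///
///   %padded = tensor.pad %src low[0, 0, 0] high[15, 0, 0] { ... }
///       : tensor<17x4x16xf32> to tensor<32x4x16xf32>
///   %unpack = tensor.unpack %padded inner_dims_pos = [1] inner_tiles = [16]
///       into %empty : tensor<32x4x16xf32> -> tensor<32x64xf32>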
1009struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
1010 PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
1011 : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}
1012
1013 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1014 PatternRewriter &rewriter) const override {
1015 tensor::UnPackOp unpackOp =
1016 padOp.getSource().getDefiningOp<tensor::UnPackOp>();
1017 if (!unpackOp)
1018 return failure();
1019
1020 if (!controlFn(padOp))
1021 return failure();
1022
1023 Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
1025 llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
1026 ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1027 llvm::SmallBitVector innerDims(paddedDims.size());
1028 for (int64_t dim : innerDimsPos)
1029 innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
1031 return failure();
1032
1033 Value paddingVal = padOp.getConstantPaddingValue();
1034 if (!paddingVal)
1035 return failure();
1036
    // If we have `outer_dims_perm` we need to adjust the padded dimensions.
1038 ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1039 SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
1040 SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
1041 if (!outerDimsPerm.empty()) {
1042 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
1043 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
1044 }
1045 // Add zero padding for the point loops.
1046 size_t pointLoopsSize = innerDimsPos.size();
1047 lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1048 highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1049
1050 auto newPadOp = rewriter.create<tensor::PadOp>(
1051 loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
1052 paddingVal, padOp.getNofold());
1053
1054 // Inject the tensor.unpack right after the packed padOp.
1055 Value outputUnPack = rewriter.create<tensor::EmptyOp>(
1056 loc, padOp.getResultType().getShape(),
1057 padOp.getResultType().getElementType());
1058
1059 Value replacement = rewriter.create<tensor::UnPackOp>(
1060 loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
1061 unpackOp.getMixedTiles(), outerDimsPerm);
1062 rewriter.replaceOp(padOp, replacement);
1063 return success();
1064 }
1065
1066private:
1067 ControlPropagationFn controlFn;
1068};
1069
1070} // namespace
1071
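// Illustrative usage sketch (an assumption for exposition, not part of this
// file): a pass would typically populate these patterns with a control
// function and run the greedy rewrite driver, e.g.:
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDataLayoutPropagationPatterns(
//       patterns, [](Operation *op) { return true; });
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     return signalPassFailure();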
void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns
      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
          patterns.getContext(), controlPackUnPackPropagation);
}