//===- DataLayoutPropagation.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/IR/Linalg.h"
10#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
11#include "mlir/Dialect/Linalg/Utils/Utils.h"
12#include "mlir/Dialect/Tensor/IR/Tensor.h"
13#include "mlir/Dialect/Utils/IndexingUtils.h"
14#include "mlir/IR/Dominance.h"
15#include "llvm/ADT/SetOperations.h"
16#include "llvm/ADT/SetVector.h"
17#include "llvm/ADT/TypeSwitch.h"
18#include "llvm/Support/Debug.h"
19#include <optional>
20
21namespace mlir {
22#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
23#include "mlir/Dialect/Linalg/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::linalg;
28
29#define DEBUG_TYPE "linalg-data-layout-propagation"
30
31namespace {
32
33static bool hasGatherSemantics(linalg::GenericOp genericOp) {
34 for (Operation &op : genericOp.getBody()->getOperations())
35 if (isa<tensor::ExtractOp, linalg::IndexOp>(Val: op))
36 return true;
37 return false;
38}
39
40// The struct contains the infomation about mapping packing information to
41// the iteration domain of Linalg ops.
42struct PackInfo {
43 int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
44 // InnerDimsPos on iteration domain, which follows the order in pack ops.
45 SmallVector<int64_t> tiledDimsPos;
46 // The sizes of tiling data dimensions on iteration domain.
47 llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
48 // The mapping from a dimension of iteration domain to the corresponding inner
49 // tiling dimension on iteration domain.
50 llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
51 // The permutation of outer dims (on domain).
52 SmallVector<int64_t> outerDimsOnDomainPerm;
53};
54
55template <typename OpTy>
56static FailureOr<PackInfo>
57getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
58 OpTy packOrUnPackOp) {
59 static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
60 "applies to only pack or unpack operations");
61 LLVM_DEBUG(
62 { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
63
64 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
65 SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
66 SmallVector<utils::IteratorType> iterators =
67 genericOp.getIteratorTypesArray();
68
69 PackInfo packInfo;
70 int64_t origNumDims = indexingMap.getNumDims();
71 SmallVector<AffineExpr> exprs(indexingMap.getResults());
72 ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
73 for (auto [index, innerDimPos, tileSize] :
74 llvm::zip_equal(llvm::seq<unsigned>(Begin: 0, End: innerDimsPos.size()),
75 innerDimsPos, packOrUnPackOp.getMixedTiles())) {
76 auto expr = exprs[innerDimPos];
77 if (!isa<AffineDimExpr>(expr))
78 return failure();
79 int64_t domainDimPos =
80 cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
81 if (!isParallelIterator(iteratorType: iterators[domainDimPos]))
82 return failure();
83 packInfo.tiledDimsPos.push_back(Elt: domainDimPos);
84 packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
85 packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
86 LLVM_DEBUG({
87 llvm::dbgs() << "map innerDimPos=" << innerDimPos
88 << " to iteration dimension (d" << domainDimPos << ", d"
89 << packInfo.tileToPointMapping[domainDimPos]
90 << "), which has size=("
91 << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
92 });
93 }
94
95 // Bail out if a tiled dimension is present in a map but not as an affine dim
96 // expression.
97 auto areAllAffineDimExpr = [&](int dim) {
98 for (AffineMap map : indexingMaps) {
99 if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
100 return expr.isFunctionOfDim(position: dim) && !isa<AffineDimExpr>(Val: expr);
101 })) {
102 return false;
103 }
104 }
105 return true;
106 };
107 for (int64_t i : packInfo.tiledDimsPos)
108 if (!areAllAffineDimExpr(i))
109 return failure();
110
111 // Get the outer dims perm on the iteration domain. Start by identifying the
112 // set of domain dims affected by the outer permutation along with the
113 // permuted ordering for those dims. Then the full outer dims permutation can
114 // be constructed by replacing the affected dims with the permuted result in a
115 // numLoops-rank identity. e.g.
116 // outerDimsPerm = [1, 2, 0]
117 // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
118 //
119 // permutedOuterDims = [4, 3, 1]
120 // outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
121 //
122 // Non-affine dim expressions must not be permuted by the outer dims
123 // permutation.
124 SmallVector<int64_t> permutedOuterDims;
125 for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
126 auto permutedExpr = indexingMap.getResult(idx: dim);
127 if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
128 permutedOuterDims.push_back(Elt: dimExpr.getPosition());
129 continue;
130 }
131
132 // TODO: Allow propagation with transposes on non affine dim expressions,
133 // e.g. d0 + d1 which implies transposing both dims simultaneously while
134 // maintaining the relative position between them.
135 if (static_cast<int64_t>(index) != dim)
136 return failure();
137 }
138 if (!permutedOuterDims.empty()) {
139 int64_t outerDimIndex = 0;
140 llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
141 permutedOuterDims.end());
142 for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
143 packInfo.outerDimsOnDomainPerm.push_back(
144 Elt: permutedDomainDims.contains(V: i) ? permutedOuterDims[outerDimIndex++]
145 : i);
146 LLVM_DEBUG({
147 llvm::dbgs() << "map outer dimsDimsPerm to ";
148 for (auto dim : packInfo.outerDimsOnDomainPerm)
149 llvm::dbgs() << dim << " ";
150 llvm::dbgs() << "\n";
151 });
152 }
153
154 return packInfo;
155}
156
157static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
158 ArrayRef<AffineExpr> exprs) {
159 // Compute `outer_dims_perm`. See example:
160 // current exprs : (d0, d1, d2, d3) -> (d2, d3)
161 // perm : [0, 3, 1, 2]
162 // First map d2, d3 with their position in the array as:
163 // currentPositionTileLoops: dim | pos
164 // d2 | 0
165 // d3 | 1
166 // then scan `perm` in order and get the `outer_dims_perm`
167 // to be used, here it would be [1, 0].
168 assert(!perm.empty() && "expect perm not to be empty");
169 assert(!exprs.empty() && "expect exprs not to be empty");
170 if (exprs.size() == 1)
171 return {};
172 SmallVector<int64_t> outerDimsPerm;
173 DenseMap<int64_t, int64_t> currentPositionTileLoops;
174 for (auto [pos, expr] : llvm::enumerate(First&: exprs)) {
175 // Here we rely on the assumption that the outer dims permutation
176 // when propagating currently requires that non-affine dim expressions
177 // are not permuted, thus allowing the identity assignment below.
178 if (auto dimExpr = dyn_cast<AffineDimExpr>(Val: expr))
179 currentPositionTileLoops[dimExpr.getPosition()] = pos;
180 else
181 currentPositionTileLoops[pos] = pos;
182 }
183 for (int64_t loopIdx : perm) {
184 if (currentPositionTileLoops.count(Val: loopIdx))
185 outerDimsPerm.push_back(Elt: currentPositionTileLoops.lookup(Val: loopIdx));
186 }
187 return outerDimsPerm;
188}
189
190/// Returns a tuple for packed operand and indexing_map with the assumptions:
191/// 1) The generic op is the producer of the pack op.
192/// 2) The generic op has only one result.
193/// If the operand is a scalar or packing dimensions are all irrelevant to the
194/// operand, the operand and the updated indexing map will be returned.
195/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
196///
197/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
198/// #map1 = affine_map<(d0, d1) -> (d0)>
199/// #map2 = affine_map<(d0, d1) -> (d1)>
200/// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
201/// iterator_types = ["parallel", "parallel"]}
202/// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
203/// outs(%init : tensor<?x?xf32>) {
204/// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
205/// %4 = arith.addf %arg3, %arg4 : f32
206/// linalg.yield %4 : f32
207/// } -> tensor<?x?xf32>
208/// %1 = linalg.pack %0
209/// inner_dims_pos = [0, 1]
210/// inner_tiles = [8, 2]
211/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
212///
213/// Taking the first input operand as an example, the inner tile size of d1 is
214/// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3)> ->
215/// affine_map<(d1, d3)>` will be returned.
216///
217/// %pack = linalg.pack %arg0
218/// inner_dims_pos = [0]
219/// inner_tiles = [8]
220/// into %init : tensor<?xf32> -> tensor<?x8xf32>
221static std::tuple<Value, AffineMap>
222getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
223 GenericOp genericOp, OpOperand *opOperand) {
224 int64_t numOrigLoops = genericOp.getNumLoops();
225 int64_t numInnerLoops = packInfo.getNumTiledLoops();
226 int64_t numLoops = numOrigLoops + numInnerLoops;
227 AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
228 llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
229 SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
230
231 // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
232 if (genericOp.isScalar(opOperand) || exprs.empty())
233 return std::make_tuple(args: opOperand->get(),
234 args: AffineMap::get(dimCount: numLoops, symbolCount: 0, results: exprs, context: b.getContext()));
235
236 // Step 1. Construct the information of packing data dimensions; append inner
237 // dimensions to the indexing maps for the operand.
238 for (auto [index, expr] : llvm::enumerate(First&: exprs)) {
239 if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr)) {
240 int64_t dimPos = dimExpr.getPosition();
241 domainDimToOperandDim[dimPos] = index;
242 continue;
243 }
244 }
245 SmallVector<int64_t> innerDimsPos;
246 SmallVector<OpFoldResult> innerTileSizes;
247 for (auto dimPos : packInfo.tiledDimsPos) {
248 if (!domainDimToOperandDim.count(Val: dimPos))
249 continue;
250 int64_t index = domainDimToOperandDim[dimPos];
251 innerTileSizes.push_back(Elt: packInfo.domainDimAndTileMapping[dimPos]);
252 innerDimsPos.push_back(Elt: index);
253 exprs.push_back(Elt: b.getAffineDimExpr(position: packInfo.tileToPointMapping[dimPos]));
254 }
255
256 // Step 2. Handle outer dim permutations.
257 SmallVector<int64_t> outerDimsPerm;
258 if (!packInfo.outerDimsOnDomainPerm.empty()) {
259 outerDimsPerm = computeOuterDims(perm: packInfo.outerDimsOnDomainPerm, exprs);
260
261 // Step 2.1: Fold transpose into the linalg.generic.
262 SmallVector<int64_t> inversedOuterPerm =
263 invertPermutationVector(permutation: packInfo.outerDimsOnDomainPerm);
264 for (auto i : llvm::seq<unsigned>(Begin: 0, End: origIndexingMap.getNumResults())) {
265 if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: exprs[i])) {
266 int64_t dimPos = dimExpr.getPosition();
267 exprs[i] = b.getAffineDimExpr(position: inversedOuterPerm[dimPos]);
268 continue;
269 }
270 assert(isa<AffineConstantExpr>(exprs[i]) &&
271 "Attempted to permute non-constant and non-affine dim expression");
272 }
273 // Step 2.2: Undo the transposition on `exprs` and propagate the
274 // transposition on the pack using outerDimsPerm.
275 if (!outerDimsPerm.empty()) {
276 SmallVector<AffineExpr> auxVec = exprs;
277 for (const auto &en : enumerate(First&: outerDimsPerm))
278 auxVec[en.index()] = exprs[en.value()];
279 exprs = auxVec;
280 }
281 }
282 auto indexingMap = AffineMap::get(dimCount: numLoops, symbolCount: 0, results: exprs, context: b.getContext());
283
284 // The operand does not have dimensions that relates to pack op.
285 if (innerDimsPos.empty() && outerDimsPerm.empty())
286 return std::make_tuple(args: opOperand->get(), args&: indexingMap);
287
288 auto empty = linalg::PackOp::createDestinationTensor(
289 b, loc, source: opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
290 auto packedOperand = b.create<linalg::PackOp>(
291 location: loc, args: opOperand->get(), args&: empty, args&: innerDimsPos, args&: innerTileSizes,
292 /*padding=*/args: std::nullopt, args&: outerDimsPerm);
293 return std::make_tuple(args&: packedOperand, args&: indexingMap);
294}
295
296/// This function is a helper subroutine to pack a genericOp and return it. It
297/// will create a new generic op with the packed operand and the packed output
298/// according to packInfo when we attempt to push down unpack or bubble up pack
299/// around it. Implicitly this will only work when a packInfo can be obtained.
300/// This make sure that we are only using this function on parallel permuted
301/// dimensions.
302static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
303 Value dest, AffineMap packedOutIndexingMap,
304 const PackInfo &packInfo,
305 bool isFoldableUnpackPack) {
306 Location loc = genericOp.getLoc();
307 SmallVector<Value> inputOperands;
308 SmallVector<Value> inputOperandsFromUnpackedSource;
309 SmallVector<AffineMap> indexingMaps;
310 auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
311 return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
312 packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
313 llvm::equal(LRange: packOp.getMixedTiles(), RRange: unPackOp.getMixedTiles());
314 };
315 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
316 auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
317 b&: rewriter, loc, packInfo, genericOp, opOperand: inputOperand);
318 auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
319 auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
320 if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
321 inputOperandsFromUnpackedSource.push_back(Elt: unpackOp.getSource());
322 } else {
323 inputOperandsFromUnpackedSource.push_back(Elt: packedOperand);
324 }
325 inputOperands.push_back(Elt: packedOperand);
326 indexingMaps.push_back(Elt: packedIndexingMap);
327 }
328
329 // If the unpack->pack sequences can be folded, replace use the sources of
330 // the unpack ops in any unpack->pack chains on the generic op operands.
331 if (isFoldableUnpackPack) {
332 inputOperands = inputOperandsFromUnpackedSource;
333 if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) {
334 auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>();
335 if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) {
336 dest = destUnPack.getSource();
337 }
338 }
339 }
340
341 int64_t numInnerLoops = packInfo.getNumTiledLoops();
342 SmallVector<utils::IteratorType> iterTypes =
343 genericOp.getIteratorTypesArray();
344 iterTypes.append(NumInputs: numInnerLoops, Elt: utils::IteratorType::parallel);
345
346 indexingMaps.push_back(Elt: packedOutIndexingMap);
347
348 auto newGenericOp = rewriter.create<linalg::GenericOp>(
349 location: loc, args: dest.getType(), args&: inputOperands, args&: dest, args&: indexingMaps, args&: iterTypes,
350 /*bodyBuild=*/args: nullptr, args: linalg::getPrunedAttributeList(op: genericOp));
351 rewriter.cloneRegionBefore(region&: genericOp.getRegion(), parent&: newGenericOp.getRegion(),
352 before: newGenericOp.getRegion().begin());
353 return newGenericOp;
354}
355
356static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
357 return llvm::all_of(Range: genericOp.getDpsInitsMutable(), P: [&](OpOperand &operand) {
358 return genericOp.getMatchingBlockArgument(opOperand: &operand).use_empty();
359 });
360}
361
362/// Bubbles up linalg.pack op through a producer generic op. This
363/// swap pack(generic) to generic(pack). The new generic op works on packed
364/// domain; pack ops are created for input and output operands. E.g.,
365///
366/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
367/// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
368/// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
369/// %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
370/// %3 = linalg.generic {indexing_maps = [#map0, #map0],
371/// iterator_types = ["parallel", "parallel"]}
372/// ins(%arg0 : tensor<?x?xf32>)
373/// outs(%2 : tensor<?x?xf32>) {
374/// ^bb0(%arg3: f32, %arg4: f32):
375/// %4 = arith.addf %arg3, %arg3 : f32
376/// linalg.yield %4 : f32
377/// } -> tensor<?x?xf32>
378/// %4 = linalg.pack %3
379/// inner_dims_pos = [0, 1]
380/// inner_tiles = [8, 2]
381/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
382///
383/// will be converted to
384///
385/// #map = affine_map<()[s0] -> (s0 ceildiv 8)>
386/// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
387/// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
388/// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
389/// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
390/// %0 = affine.apply #map()[%dim]
391/// %1 = affine.apply #map1()[%dim_0]
392/// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
393/// %pack = linalg.pack %arg0
394/// inner_dims_pos = [0, 1]
395/// inner_tiles = [8, 2]
396/// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
397/// %3 = linalg.generic {indexing_maps = [#map2, #map2],
398/// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
399/// ins(%pack : tensor<?x?x8x2xf32>)
400/// outs(%arg1 : tensor<?x?x8x2xf32>) {
401/// ^bb0(%in: f32, %out: f32):
402/// %4 = arith.addf %in, %in : f32
403/// linalg.yield %4 : f32
404/// } -> tensor<?x?x8x2xf32>
405static FailureOr<GenericOp>
406bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
407 const ControlPropagationFn &controlFn) {
408 auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
409 if (!genericOp)
410 return failure();
411
412 // User controlled propagation function.
413 if (!controlFn(&packOp.getSourceMutable()))
414 return failure();
415
416 // TODO: Enable propagation in the presence of linalg.index and
417 // tensor.extract, likely as a separate pattern as the pack information and
418 // propagation decision needs to be inferred from the region of the generic.
419 if (hasGatherSemantics(genericOp))
420 return failure();
421
422 // TODO: Relax the restriction. We are able to bubble up the pack op through
423 // multi-result generic op. It just needs more work.
424 if (genericOp.getNumResults() != 1)
425 return failure();
426
427 // Bail-out if the result of the generic has multiple uses, as bubbling up
428 // creates recomputation if the generic has multiple users.
429 // TODO: Enable the case where every use is an identical pack op as no
430 // recomputation is needed in that case.
431 if (!genericOp->getResult(idx: 0).hasOneUse())
432 return failure();
433
434 // TODO: Add an option for allowing padding values. It could introduce
435 // undefined behavior if we unconditionally propagate pack op through all
436 // the ops. E.g., if the padding value is zero and there are division ops in
437 // a generic op. Some values of padding area could be NaN (0/0).
438 if (packOp.getPaddingValue())
439 return failure();
440
441 OpOperand *opOperand = genericOp.getDpsInitOperand(i: 0);
442 auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOrUnPackOp: packOp);
443 if (failed(Result: packInfo))
444 return failure();
445
446 // We want to move the pack not the generic.
447 OpBuilder::InsertionGuard guard(rewriter);
448 rewriter.setInsertionPoint(genericOp);
449
450 // We need to handle two cases:
451 // 1) The linalg.pack destination is a tensor.empty. If this is the case, we
452 // create a new tensor.empty to avoid breaking dominance, as we are moving the
453 // linalg.pack above the linalg.generic.
454 // 2) The destination is not a tensor.empty. In this case we can replace only
455 // if the destination of the linalg.pack dominates the linalg.generic.
456 Value packOpDest = packOp.getDest();
457 if (!packOpDest.hasOneUse())
458 return failure();
459 if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
460 packOpDest = rewriter.create<tensor::EmptyOp>(
461 location: genericOp->getLoc(), args: emptyOp.getMixedSizes(),
462 args: emptyOp.getType().getElementType());
463 } else {
464 DominanceInfo dom(genericOp);
465 if (!dom.properlyDominates(a: packOpDest, b: genericOp))
466 return failure();
467 }
468
469 // Rebuild the indexing map for the corresponding init operand.
470 auto [packedOutOperand, packedOutIndexingMap] =
471 getOrCreatePackedViewOfOperand(b&: rewriter, loc: genericOp.getLoc(), packInfo: *packInfo,
472 genericOp, opOperand);
473
474 // Forward the new tensor.empty as a destination if it is one of the following
475 // situations:
476 // 1) The dps init operand is a tensor.empty.
477 // 2) The dps init is a write-only operand, i.e., it is not used in the
478 // genericOp
479 Value dest = packedOutOperand;
480 auto initTensor =
481 genericOp.getDpsInitOperand(i: 0)->get().getDefiningOp<tensor::EmptyOp>();
482 if (initTensor || isGenericOutsNotUsed(genericOp)) {
483 dest = packOpDest;
484 }
485 // pack(unpack) isn't naively foldable because the unpack op can be from
486 // an arbitrary domain so we need to keep both.
487 return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
488 packInfo: *packInfo, /*isFoldableUnpackPack=*/false);
489}
490
491/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
492struct BubbleUpPackOpThroughGenericOpPattern
493 : public OpRewritePattern<linalg::PackOp> {
494public:
495 BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
496 ControlPropagationFn fun)
497 : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
498
499 LogicalResult matchAndRewrite(linalg::PackOp packOp,
500 PatternRewriter &rewriter) const override {
501 auto genericOp =
502 bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
503 if (failed(Result: genericOp))
504 return failure();
505 rewriter.replaceOp(op: packOp, newValues: genericOp->getResults());
506 return success();
507 }
508
509private:
510 ControlPropagationFn controlFn;
511};
512
513/// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
514/// add as many zero padding dimensions in `high` and `low` based on the number
515/// of point loops.
516class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
517public:
518 BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
519 : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
520
521 LogicalResult matchAndRewrite(linalg::PackOp packOp,
522 PatternRewriter &rewriter) const override {
523 auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
524 if (!padOp)
525 return failure();
526
527 // User controlled propagation function.
528 if (!controlFn(&packOp.getSourceMutable()))
529 return failure();
530
531 // TODO: Enable padding when the padding values are the same.
532 if (packOp.getPaddingValue())
533 return failure();
534
535 // Fail for non-constant padding values. The body of the pad could
536 // depend on the padding indices and/or properties of the padded
537 // tensor so for now we fail.
538 // TODO: Support non-constant padding values.
539 Value paddingVal = padOp.getConstantPaddingValue();
540 if (!paddingVal)
541 return failure();
542
543 if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
544 return failure();
545
546 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
547
548 // Bail out if one of the padded dimension is a tiled one.
549 llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
550 llvm::SmallBitVector innerDims(paddedDims.size());
551 for (int64_t dim : innerDimsPos)
552 innerDims.flip(Idx: dim);
553 if (paddedDims.anyCommon(RHS: innerDims))
554 return failure();
555
556 Location loc = padOp->getLoc();
557 OpBuilder::InsertionGuard guard(rewriter);
558 rewriter.setInsertionPoint(padOp);
559
560 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
561 SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
562 auto empty = linalg::PackOp::createDestinationTensor(
563 b&: rewriter, loc, source: padOp.getSource(), innerTileSizes: mixedTiles, innerDimsPos,
564 outerDimsPerm);
565 auto sourcePack = rewriter.create<linalg::PackOp>(
566 location: loc, args: padOp.getSource(), args&: empty, args&: innerDimsPos, args&: mixedTiles,
567 /*padding=*/args: std::nullopt, args&: outerDimsPerm);
568
569 // If we have `outer_dims_perms` we need to adjust the padded dimensions.
570 SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
571 SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
572 if (!outerDimsPerm.empty()) {
573 applyPermutationToVector<OpFoldResult>(inVec&: lowPad, permutation: outerDimsPerm);
574 applyPermutationToVector<OpFoldResult>(inVec&: highPad, permutation: outerDimsPerm);
575 }
576 // The tiled dimensions were verified to be unpadded above, so here we
577 // just append 0 for the inner tile dimensions.
578 size_t pointLoopsSize = innerDimsPos.size();
579 lowPad.append(NumInputs: pointLoopsSize, Elt: rewriter.getIndexAttr(value: 0));
580 highPad.append(NumInputs: pointLoopsSize, Elt: rewriter.getIndexAttr(value: 0));
581
582 auto newPadOp = rewriter.create<tensor::PadOp>(
583 location: loc, /*result=*/args: Type(), args&: sourcePack, args&: lowPad, args&: highPad, args&: paddingVal,
584 args: padOp.getNofold());
585
586 // If the pad has more than one user, create an unpack on the new pad to
587 // replace the other uses.
588 if (!padOp->hasOneUse()) {
589 auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
590 b&: rewriter, loc, source: newPadOp, innerTileSizes: mixedTiles, innerDimsPos, outerDimsPerm);
591 Value unpackedPad = rewriter.create<linalg::UnPackOp>(
592 location: loc, args&: newPadOp, args&: unpackEmpty, args&: innerDimsPos, args&: mixedTiles, args&: outerDimsPerm);
593 rewriter.replaceAllUsesExcept(from: padOp, to: unpackedPad, exceptedUser: sourcePack);
594 }
595
596 // Replace the pack with the new pad.
597 rewriter.replaceOp(op: packOp, newValues: newPadOp.getResult());
598
599 return success();
600 }
601
602private:
603 ControlPropagationFn controlFn;
604};
605
606/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
607///
608/// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
609/// targetShape [16, 16, 32, 1], it returns [1, 2]. Because for pos 0, the
610/// inner-most projected dim in pos [0, 1] is 1. And for pos 2, the inner-most
611/// non-unit projected dims in pos [2, 3] is 2.
612///
613/// If all candidates in a reassociation are unit dims, it chooses the
614/// inner-most dim pos.
615static SmallVector<int64_t>
616projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
617 ArrayRef<ReassociationIndices> reassocIndices,
618 ArrayRef<int64_t> targetShape) {
619 SmallVector<int64_t> projectedDimsPos;
620 for (auto pos : dimsPos) {
621 // In the case all dims are unit, this will return the inner-most one.
622 int64_t projectedPos = reassocIndices[pos].back();
623 for (auto i : llvm::reverse(C: reassocIndices[pos])) {
624 int64_t dim = targetShape[i];
625 if (dim > 1 || ShapedType::isDynamic(dValue: dim)) {
626 projectedPos = i;
627 break;
628 }
629 }
630 projectedDimsPos.push_back(Elt: projectedPos);
631 }
632 return projectedDimsPos;
633}
634
635/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
636static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
637 ArrayRef<int64_t> shape,
638 ArrayRef<int64_t> tileSizes) {
639 for (auto [pos, tileSize] : llvm::zip_equal(t&: dimsPos, u&: tileSizes)) {
640 int64_t dim = shape[pos];
641 if (ShapedType::isDynamic(dValue: dim) || (dim % tileSize) != 0)
642 return false;
643 }
644 return true;
645}
646
647/// Permutate the reassociation indices and reindex them in the sequence order.
648/// Returns the next dim pos in the sequence.
649///
650/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
651/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
652/// [[0], [1, 2]].
653static int64_t applyPermutationAndReindexReassoc(
654 SmallVector<ReassociationIndices> &reassocIndices,
655 ArrayRef<int64_t> permutation) {
656 if (!permutation.empty())
657 applyPermutationToVector<ReassociationIndices>(inVec&: reassocIndices, permutation);
658 int64_t nextPos = 0;
659 for (ReassociationIndices &indices : reassocIndices) {
660 for (auto &index : indices) {
661 index = nextPos;
662 nextPos += 1;
663 }
664 }
665 return nextPos;
666}
667
668/// Bubble up pack op through collapse shape op when the packed dims can be
669/// projected to the dims before collapsing. This is possible when the inner
670/// tile sizes can divide the projected dims.
671///
672/// For example:
673///
674/// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
675/// : tensor<?x16x4xf32> into tensor<?x4xf32>
676/// %pack = linalg.pack %collapsed outer_dims_perm = [0, 1]
677/// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
678/// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
679///
680/// can be transformed into:
681///
682/// %pack = linalg.pack %in outer_dims_perm = [1, 2]
683/// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
684/// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
685/// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
686/// : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1>
687static LogicalResult
688bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
689 linalg::PackOp packOp,
690 PatternRewriter &rewriter) {
691 SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
692 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
693 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
694
695 ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
696 SmallVector<ReassociationIndices> reassocIndices =
697 collapseOp.getReassociationIndices();
698 // Project inner tile pos to the dim pos before collapsing. For example, if
699 // dims [x, y] is collapsed into [z], packing on dim z can be projected back
700 // to pack on dim y.
701 //
702 // Project to inner-most non-unit dims to increase the chance that they can be
703 // divided by the inner tile sizes. This is correct because for [..., x, 1],
704 // packing on dim 1 is equivalent to packing on dim x.
705 SmallVector<int64_t> projectedInnerDimsPos =
706 projectToInnerMostNonUnitDimsPos(dimsPos: innerDimsPos, reassocIndices, targetShape: srcShape);
707
708 if (!isDimsDivisibleByTileSizes(dimsPos: projectedInnerDimsPos, shape: srcShape,
709 tileSizes: innerTileSizes)) {
710 return failure();
711 }
712 // Expand the outer dims permutation with the associated source dims for the
713 // new permutation after bubbling. This is because moving a collapsed dim is
714 // equivalent to moving the associated source dims together.
715 SmallVector<int64_t> newOuterDimsPerm;
716 for (auto outerPos : outerDimsPerm)
717 llvm::append_range(C&: newOuterDimsPerm, R&: reassocIndices[outerPos]);
718
719 auto emptyOp = linalg::PackOp::createDestinationTensor(
720 b&: rewriter, loc: packOp.getLoc(), source: collapseOp.getSrc(), innerTileSizes: packOp.getMixedTiles(),
721 innerDimsPos: projectedInnerDimsPos, outerDimsPerm: newOuterDimsPerm);
722 auto newPackOp = rewriter.create<linalg::PackOp>(
723 location: packOp.getLoc(), args: collapseOp.getSrc(), args&: emptyOp, args&: projectedInnerDimsPos,
724 args: packOp.getMixedTiles(), args: packOp.getPaddingValue(), args&: newOuterDimsPerm);
725
726 SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
727 // First apply the permutation on the reassociations of the outer dims.
728 // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
729 // -> [[0], [1, 2]]
730 int64_t nextPos =
731 applyPermutationAndReindexReassoc(reassocIndices&: newReassocIndices, permutation: outerDimsPerm);
732 // Then add direct mapping for the inner tile dims.
733 for (size_t i = 0; i < innerDimsPos.size(); ++i) {
734 newReassocIndices.push_back(Elt: {nextPos});
735 nextPos += 1;
736 }
737
738 auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
739 location: collapseOp.getLoc(), args: packOp.getType(), args&: newPackOp, args&: newReassocIndices);
740 rewriter.replaceOp(op: packOp, newOp: newCollapseOp);
741
742 return success();
743}
744
745/// Project dimsPos to their collapsed positions in the reassocIndices.
746///
747/// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
748/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
749/// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
750/// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
751static SmallVector<int64_t>
752projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
753 ArrayRef<ReassociationIndices> reassocIndices) {
754 SmallVector<int64_t> projectedPos;
755
756 // Map each dimension to the position of corresponding reassociation index.
757 for (auto pos : dimsPos) {
758 for (auto [idx, indices] : llvm::enumerate(First&: reassocIndices)) {
759 // If the dimension is present in the current indices group, the group
760 // position within the reassociation map is the desired projected
761 // dimension position.
762 if (llvm::is_contained(Range: indices, Element: pos)) {
763 projectedPos.push_back(Elt: idx);
764 break;
765 }
766 }
767 }
768 assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
769
770 return projectedPos;
771}
772
/// Bubble up pack op through expand shape op.
///
/// For example:
///
/// %expand = tensor.expand_shape %in [[0], [1, 2]]
/// : tensor<?x64xf32> into tensor<?x4x16xf32>
/// %pack = linalg.pack %expand outer_dims_perm = [0, 1]
/// inner_dims_pos = [2] inner_tiles = [8] into %empty
/// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
///
/// can be transformed into:
///
/// %pack = linalg.pack %in outer_dims_perm = [1, 2]
/// inner_dims_pos = [1] inner_tiles = [8] into %empty
/// : tensor<?x64xf32> -> tensor<?x8x8xf32>
/// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
/// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
///
/// Fails (with a match-failure diagnostic) when the reordering is not legal:
/// non-identity outer_dims_perm, more than one packed dim per reassociation
/// group, or packing a dim that is not the inner-most expanded one.
static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
                                 linalg::PackOp packOp,
                                 PatternRewriter &rewriter) {
  // Outer dimensions permutation is not supported currently.
  // TODO: Handle outer_dims_perm variants.
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty() && !isIdentityPermutation(permutation: outerDimsPerm)) {
    return rewriter.notifyMatchFailure(arg&: packOp,
                                       msg: "non-identity outer dims perm NYI");
  }

  // Validate dimensions' relations between shape expansion and packing.
  SmallVector<ReassociationIndices, 4> reassoc =
      expandOp.getReassociationIndices();
  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
  llvm::SetVector<int64_t> packDimsPos(llvm::from_range, packInnerDims);

  for (auto [idx, indices] : llvm::enumerate(First&: reassoc)) {
    // For each expand_shape reassociation, figure out which dimensions get
    // packed if any.
    llvm::SetVector<int64_t> expandDimPos(llvm::from_range, indices);
    llvm::SetVector<int64_t> packedDims =
        llvm::set_intersection(S1: packDimsPos, S2: expandDimPos);

    // The expanded dimension is not packed so, it does not affect moving pack
    // before shape expansion - simply continue.
    if (packedDims.empty())
      continue;
    // Shape expansion cannot be propagated when multiple expanded dimension are
    // packed - in this case operation reordering would affect final element
    // positions and/or shapes can no longer be projected.
    if (packedDims.size() != 1)
      return rewriter.notifyMatchFailure(
          arg&: packOp, msg: "only one of the expanded dimensions can be packed");
    // Only the inner-most expanded dimension should be packed. Otherwise,
    // elements order will be affected after operation reordering.
    if (packedDims.front() != indices.back())
      return rewriter.notifyMatchFailure(
          arg&: packOp, msg: "can only pack the inner-most expanded dimension");
  }

  // Project pack.inner_dims_pos to positions before shape expansion.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectDimsPosIntoReassocPos(dimsPos: packInnerDims, reassocIndices: reassoc);

  // Project the shape expansion to new packed shape.
  // The pack.outer_dims_perm is restricted to identity so, the permutation can
  // be omitted for simplicity.
  // TODO: Account for outer dimensions permutation.
  //
  // If reassociation is not possible, then reordering cannot happen.
  // This can be caused by pack padding affecting previously expanded
  // dimensions or packing extending dimensions.
  RankedTensorType newPackType = linalg::PackOp::inferPackedType(
      sourceType: expandOp.getSrcType(), innerTileSizes: packOp.getStaticInnerTiles(),
      innerDimsPos: projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  auto reassocExpand =
      getReassociationIndicesForReshape(sourceType: newPackType, targetType: packOp.getDestType());
  if (!reassocExpand)
    return rewriter.notifyMatchFailure(
        arg&: packOp, msg: "could not reassociate dims after bubbling up");

  // Materialize the new pack directly on the pre-expansion source.
  Value destTensor = linalg::PackOp::createDestinationTensor(
      b&: rewriter, loc: packOp.getLoc(), source: expandOp.getSrc(), innerTileSizes: packOp.getMixedTiles(),
      innerDimsPos: projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  Value packedVal = rewriter.create<linalg::PackOp>(
      location: packOp.getLoc(), args: expandOp.getSrc(), args&: destTensor, args&: projectedInnerDimsPos,
      args: packOp.getMixedTiles(), args: packOp.getPaddingValue(),
      /*outerDimsPerm=*/args: SmallVector<int64_t>{});

  // Re-expand the packed result so the overall value matches the original
  // pack's destination type, then replace the original pack.
  Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      location: packOp.getLoc(), args: packOp.getDestType(), args&: packedVal, args&: *reassocExpand);
  rewriter.replaceOp(op: packOp, newValues: newExpandOp);

  return success();
}
867
868class BubbleUpPackOpThroughReshapeOp final
869 : public OpRewritePattern<linalg::PackOp> {
870public:
871 BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
872 : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
873
874 LogicalResult matchAndRewrite(linalg::PackOp packOp,
875 PatternRewriter &rewriter) const override {
876 Operation *srcOp = packOp.getSource().getDefiningOp();
877 // Currently only support when the pack op is the only user.
878 if (!srcOp || !(srcOp->getNumResults() == 1) ||
879 !srcOp->getResult(idx: 0).hasOneUse()) {
880 return failure();
881 }
882 // Currently only support static inner tile sizes.
883 if (llvm::any_of(Range: packOp.getStaticTiles(), P: ShapedType::isDynamic))
884 return failure();
885
886 // User controlled propagation function.
887 if (!controlFn(&packOp.getSourceMutable()))
888 return failure();
889
890 return TypeSwitch<Operation *, LogicalResult>(srcOp)
891 .Case(caseFn: [&](tensor::CollapseShapeOp op) {
892 return bubbleUpPackOpThroughCollapseShape(collapseOp: op, packOp, rewriter);
893 })
894 .Case(caseFn: [&](tensor::ExpandShapeOp op) {
895 return bubbleUpPackOpThroughExpandShape(expandOp: op, packOp, rewriter);
896 })
897 .Default(defaultFn: [](Operation *) { return failure(); });
898 }
899
900private:
901 ControlPropagationFn controlFn;
902};
903
/// Push down unpack op through expand shape op when the packed dims can be
/// projected to the dims after expanding. This is possible when the inner tile
/// sizes can divide the projected dims.
///
/// For example:
///
/// %unpack = linalg.unpack %in outer_dims_perm = [0, 1]
/// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
/// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
/// : tensor<?x256xf32> into tensor<?x256x256xf32>
///
/// can be transformed into:
///
/// %expanded = tensor.expand_shape %ain [[0, 1], [2], [3], [4]]
/// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
/// %unpack = linalg.unpack %expanded outer_dims_perm = [0, 1, 2]
/// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
/// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
static LogicalResult pushDownUnPackOpThroughExpandShape(
    linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
    PatternRewriter &rewriter, ControlPropagationFn controlFn) {
  // User controlled propagation function.
  if (!controlFn(&expandOp.getSrcMutable()))
    return failure();

  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

  // The projection below requires a ranked result type.
  auto expandTy = dyn_cast<RankedTensorType>(Val: expandOp.getType());
  if (!expandTy)
    return failure();
  ArrayRef<int64_t> dstShape = expandTy.getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      expandOp.getReassociationIndices();
  // Project inner tile pos to the dim pos after expanding. For example, if dims
  // [z] is expanded into [x, y], unpacking on dim z can be projected to unpack
  // on dim y.
  //
  // Project to inner-most non-unit dims to increase the chance that they can be
  // divided by the inner tile sizes. This is correct because for [..., x, 1],
  // unpacking on dim 1 is equivalent to unpacking on dim x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(dimsPos: innerDimsPos, reassocIndices, targetShape: dstShape);

  // Each projected dim must be divisible by its inner tile size, otherwise
  // the reordered unpack would need extra padding/trimming.
  if (!isDimsDivisibleByTileSizes(dimsPos: projectedInnerDimsPos, shape: dstShape,
                                  tileSizes: innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated expanded dims for the
  // new permutation after pushing. This is because moving a source dim is
  // equivalent to moving the associated expanded dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm)
    llvm::append_range(C&: newOuterDimsPerm, R&: reassocIndices[outerPos]);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims.
  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
  // -> [[0], [1, 2]]
  int64_t nextPos =
      applyPermutationAndReindexReassoc(reassocIndices&: newReassocIndices, permutation: outerDimsPerm);
  // Then add direct mapping for the inner tile dims.
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back(Elt: {nextPos});
    nextPos += 1;
  }

  // Infer the type of the intermediate (expanded but still packed) tensor and
  // create the new expand_shape on the packed source.
  RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
      sourceType: expandTy, innerTileSizes, innerDimsPos: projectedInnerDimsPos, outerDimsPerm: newOuterDimsPerm);
  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      location: expandOp.getLoc(), args&: newExpandType, args: unPackOp.getSource(),
      args&: newReassocIndices);

  // Unpack the expanded tensor, replacing the original expand_shape result.
  auto emptyOp = linalg::UnPackOp::createDestinationTensor(
      b&: rewriter, loc: unPackOp.getLoc(), source: newExpandOp, innerTileSizes: unPackOp.getMixedTiles(),
      innerDimsPos: projectedInnerDimsPos, outerDimsPerm: newOuterDimsPerm);
  auto newUnPackOp = rewriter.create<linalg::UnPackOp>(
      location: unPackOp.getLoc(), args: newExpandOp.getResult(), args&: emptyOp,
      args&: projectedInnerDimsPos, args: unPackOp.getMixedTiles(), args&: newOuterDimsPerm);
  rewriter.replaceOp(op: expandOp, newOp: newUnPackOp);

  return success();
}
989
990class PushDownUnPackOpThroughReshapeOp final
991 : public OpRewritePattern<linalg::UnPackOp> {
992public:
993 PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
994 ControlPropagationFn fun)
995 : OpRewritePattern<linalg::UnPackOp>(context), controlFn(std::move(fun)) {
996 }
997
998 LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
999 PatternRewriter &rewriter) const override {
1000 Value result = unPackOp.getResult();
1001 // Currently only support unpack op with the single user.
1002 if (!result.hasOneUse()) {
1003 return failure();
1004 }
1005 // Currently only support static inner tile sizes.
1006 if (llvm::any_of(Range: unPackOp.getStaticTiles(), P: ShapedType::isDynamic))
1007 return failure();
1008
1009 Operation *consumerOp = *result.user_begin();
1010 return TypeSwitch<Operation *, LogicalResult>(consumerOp)
1011 .Case(caseFn: [&](tensor::ExpandShapeOp op) {
1012 return pushDownUnPackOpThroughExpandShape(unPackOp, expandOp: op, rewriter,
1013 controlFn);
1014 })
1015 .Default(defaultFn: [](Operation *) { return failure(); });
1016 }
1017
1018private:
1019 ControlPropagationFn controlFn;
1020};
1021
1022// TODO: Relax this restriction. We should unpack a generic op also
1023// in the presence of multiple unpack ops as producers.
1024/// Return the unpacked operand, if present, for the current generic op.
1025static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
1026 OpOperand *unPackedOperand = nullptr;
1027 for (OpOperand &operand : genericOp->getOpOperands()) {
1028 auto unPackOp = operand.get().getDefiningOp<linalg::UnPackOp>();
1029 if (!unPackOp)
1030 continue;
1031 if (unPackedOperand)
1032 return failure();
1033 unPackedOperand = &operand;
1034 }
1035 if (!unPackedOperand)
1036 return failure();
1037 return unPackedOperand;
1038}
1039
/// Push down a linalg.unpack op through a generic op.
/// The new generic op works on packed domain; pack ops are created for input
/// and output operands. A linalg.unpack op is inserted right after the packed
/// generic. E.g.
///
/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///
/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
/// inner_dims_pos = [3] inner_tiles = [32] into %0
/// %2 = linalg.generic {indexing_maps = [#map],
/// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
/// outs(%1 : tensor<12x56x56x64xf32>) {
/// ^bb0(%out : f32):
/// linalg.yield %out : f32
/// } -> tensor<12x56x56x64xf32>
///
/// will be converted to
///
/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = linalg.generic {indexing_maps = [#map],
/// iterator_types = ["parallel", "parallel", "parallel",
/// "parallel", "parallel"]}
/// outs(%arg0 : tensor<12x2x56x56x32xf32>) {
/// ^bb0(%out : f32):
/// linalg.yield %out : f32
/// } -> tensor<12x2x56x56x32xf32>
/// %2 = linalg.unpack %1 outer_dims_perm = [0, 3, 1, 2]
/// inner_dims_pos = [3] inner_tiles = [32] into %0
///
/// On success, returns the new (packed-domain) generic op together with the
/// value that should replace the original generic's result.
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                 ControlPropagationFn controlFn) {
  // Only single-result generics are supported.
  if (genericOp.getNumResults() != 1)
    return failure();

  // tensor.extract / linalg.index in the body make data movement hard to
  // reason about, so bail out.
  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unPacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(Result: maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  linalg::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<linalg::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");

  if (!controlFn(unPackedOperand))
    return failure();

  auto packInfo =
      getPackingInfoFromOperand(opOperand: unPackedOperand, genericOp, packOrUnPackOp: producerUnPackOp);
  if (failed(Result: packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(b&: rewriter, loc: genericOp.getLoc(), packInfo: *packInfo,
                                     genericOp, opOperand: genericOp.getDpsInitOperand(i: 0));
  auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();

  // Forward the new tensor.empty as a destination if it is one of the following
  // situations:
  // 1) The dps init operand is a tensor.empty.
  // 2) The dps init is a write-only operand, i.e., it is not used in the
  // genericOp
  Value dest = packedOutOperand;
  auto initTensor =
      genericOp.getDpsInitOperand(i: 0)->get().getDefiningOp<tensor::EmptyOp>();
  if (initTensor || isGenericOutsNotUsed(genericOp)) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the genericOp.
  // pack(unpack) is foldable in this case. This is because in pushing down the
  // unpack, by default we will populate an additional pack op after the unpack.
  // This guarantees them to be foldable.
  GenericOp newGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, packInfo: *packInfo,
                    /*isFoldableUnpackPack=*/true);
  Value newResult =
      newGenericOp.getTiedOpResult(opOperand: newGenericOp.getDpsInitOperand(i: 0));

  // If the output is unaffected, no need to unpack.
  if (!destPack)
    return std::make_tuple(args&: newGenericOp, args&: newResult);

  // Mirror the destination pack's layout when building the trailing unpack.
  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // Insert an unPackOp right after the packed generic.
  Value unPackOpRes =
      rewriter
          .create<linalg::UnPackOp>(location: genericOp.getLoc(), args&: newResult,
                                    args: destPack.getSource(), args&: innerDimsPos,
                                    args&: mixedTiles, args&: outerDimsPerm)
          .getResult();

  return std::make_tuple(args&: newGenericOp, args&: unPackOpRes);
}
1149
1150// Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
1151struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
1152public:
1153 PushDownUnPackOpThroughGenericOp(MLIRContext *context,
1154 ControlPropagationFn fun)
1155 : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1156
1157 LogicalResult matchAndRewrite(GenericOp genericOp,
1158 PatternRewriter &rewriter) const override {
1159 auto genericAndRepl =
1160 pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
1161 if (failed(Result: genericAndRepl))
1162 return failure();
1163 rewriter.replaceOp(op: genericOp, newValues: std::get<1>(t&: *genericAndRepl));
1164 return success();
1165 }
1166
1167private:
1168 ControlPropagationFn controlFn;
1169};
1170
/// Propagate a linalg.unpack operation through a tensor.pad. The idea is to
/// add as many zero padding dimensions in `high` and `low` based on the number
/// of point loops.
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // The pad's source must come directly from an unpack.
    linalg::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<linalg::UnPackOp>();
    if (!unpackOp)
      return failure();

    // User controlled propagation function.
    if (!controlFn(&padOp.getSourceMutable()))
      return failure();

    Location loc = padOp.getLoc();
    // Bail out if one of the padded dimension is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(Idx: dim);
    if (paddedDims.anyCommon(RHS: innerDims))
      return failure();

    // Only a constant padding value can be hoisted across the unpack; a
    // non-constant pad region cannot be reproduced on the packed layout.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // If we have `outer_dims_perms` we need to adjust the padded dimensions.
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(inVec&: lowPad, permutation: outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(inVec&: highPad, permutation: outerDimsPerm);
    }
    // Add zero padding for the point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(NumInputs: pointLoopsSize, Elt: rewriter.getIndexAttr(value: 0));
    highPad.append(NumInputs: pointLoopsSize, Elt: rewriter.getIndexAttr(value: 0));

    // Create the new pad on the still-packed source tensor.
    auto newPadOp = rewriter.create<tensor::PadOp>(
        location: loc, /*result=*/args: Type(), args: unpackOp.getSource(), args&: lowPad, args&: highPad,
        args&: paddingVal, args: padOp.getNofold());

    // Inject the linalg.unpack right after the packed padOp.
    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
        location: loc, args: padOp.getResultType().getShape(),
        args: padOp.getResultType().getElementType());

    Value replacement = rewriter.create<linalg::UnPackOp>(
        location: loc, args: newPadOp.getResult(), args&: outputUnPack, args&: innerDimsPos,
        args: unpackOp.getMixedTiles(), args&: outerDimsPerm);
    rewriter.replaceOp(op: padOp, newValues: replacement);
    return success();
  }

private:
  // User-provided predicate gating whether propagation may happen.
  ControlPropagationFn controlFn;
};
1234
1235} // namespace
1236
1237void mlir::linalg::populateDataLayoutPropagationPatterns(
1238 RewritePatternSet &patterns,
1239 const ControlPropagationFn &controlPackUnPackPropagation) {
1240 patterns
1241 .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
1242 BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
1243 PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1244 arg: patterns.getContext(), args: controlPackUnPackPropagation);
1245}
1246

// source code of mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp