//===- DataLayoutPropagation.cpp -----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Passes.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Linalg/IR/Linalg.h"
13#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14#include "mlir/Dialect/Linalg/Utils/Utils.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/Dialect/Tensor/Utils/Utils.h"
17#include "mlir/Dialect/Utils/IndexingUtils.h"
18#include "mlir/IR/Dominance.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20#include "llvm/ADT/SetOperations.h"
21#include "llvm/ADT/SetVector.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/Debug.h"
24#include <optional>
25
26namespace mlir {
27#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
28#include "mlir/Dialect/Linalg/Passes.h.inc"
29} // namespace mlir
30
31using namespace mlir;
32using namespace mlir::linalg;
33
34#define DEBUG_TYPE "linalg-data-layout-propagation"
35
36namespace {
37
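/// Returns true if the body of the generic op reads iteration indices
/// (linalg.index) or performs data-dependent accesses (tensor.extract);
/// layout propagation is not supported for such gather-like ops.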
38static bool hasGatherSemantics(linalg::GenericOp genericOp) {
39 for (Operation &op : genericOp.getBody()->getOperations())
40 if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
41 return true;
42 return false;
43}
44
// This struct captures how the packing information of a pack/unpack op maps
// onto the iteration domain of Linalg ops.
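//
// For example (an illustrative case derived from the code below): for a
// generic with two parallel loops (d0, d1) whose packed operand uses the map
// (d0, d1) -> (d0, d1), a pack with inner_dims_pos = [0, 1] and
// inner_tiles = [8, 2] yields:
//   tiledDimsPos            = [0, 1]
//   domainDimAndTileMapping = {0 -> 8, 1 -> 2}
//   tileToPointMapping      = {0 -> 2, 1 -> 3}  (d2, d3 are the new point loops)
//   outerDimsOnDomainPerm   = []                (no outer permutation)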
47struct PackInfo {
48 int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
49 // InnerDimsPos on iteration domain, which follows the order in pack ops.
50 SmallVector<int64_t> tiledDimsPos;
51 // The sizes of tiling data dimensions on iteration domain.
52 llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
53 // The mapping from a dimension of iteration domain to the corresponding inner
54 // tiling dimension on iteration domain.
55 llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
56 // The permutation of outer dims (on domain).
57 SmallVector<int64_t> outerDimsOnDomainPerm;
58};
59
60template <typename OpTy>
61static FailureOr<PackInfo>
62getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
63 OpTy packOrUnPackOp) {
64 static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
65 "applies to only pack or unpack operations");
66 LLVM_DEBUG(
67 { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
68
69 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
70 SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
71 SmallVector<utils::IteratorType> iterators =
72 genericOp.getIteratorTypesArray();
73
74 PackInfo packInfo;
75 int64_t origNumDims = indexingMap.getNumDims();
76 SmallVector<AffineExpr> exprs(indexingMap.getResults());
77 ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
78 for (auto [index, innerDimPos, tileSize] :
79 llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
80 innerDimsPos, packOrUnPackOp.getMixedTiles())) {
81 auto expr = exprs[innerDimPos];
82 if (!isa<AffineDimExpr>(expr))
83 return failure();
84 int64_t domainDimPos =
85 cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
86 if (!isParallelIterator(iterators[domainDimPos]))
87 return failure();
88 packInfo.tiledDimsPos.push_back(domainDimPos);
89 packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
90 packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
91 LLVM_DEBUG({
92 llvm::dbgs() << "map innerDimPos=" << innerDimPos
93 << " to iteration dimension (d" << domainDimPos << ", d"
94 << packInfo.tileToPointMapping[domainDimPos]
95 << "), which has size=("
96 << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
97 });
98 }
99
100 // Bail out if a tiled dimension is present in a map but not as an affine dim
101 // expression.
102 auto areAllAffineDimExpr = [&](int dim) {
103 for (AffineMap map : indexingMaps) {
104 if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
105 return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
106 })) {
107 return false;
108 }
109 }
110 return true;
111 };
112 for (int64_t i : packInfo.tiledDimsPos)
113 if (!areAllAffineDimExpr(i))
114 return failure();
115
116 // Get the outer dims perm on the iteration domain. Start by identifying the
117 // set of domain dims affected by the outer permutation along with the
118 // permuted ordering for those dims. Then the full outer dims permutation can
119 // be constructed by replacing the affected dims with the permuted result in a
120 // numLoops-rank identity. e.g.
121 // outerDimsPerm = [1, 2, 0]
122 // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
123 //
124 // permutedOuterDims = [4, 3, 1]
125 // outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
126 //
127 // Non-affine dim expressions must not be permuted by the outer dims
128 // permutation.
129 SmallVector<int64_t> permutedOuterDims;
130 for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
132 if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
133 permutedOuterDims.push_back(dimExpr.getPosition());
134 continue;
135 }
136
137 // TODO: Allow propagation with transposes on non affine dim expressions,
138 // e.g. d0 + d1 which implies transposing both dims simultaneously while
139 // maintaining the relative position between them.
140 if (static_cast<int64_t>(index) != dim)
141 return failure();
142 }
143 if (!permutedOuterDims.empty()) {
144 int64_t outerDimIndex = 0;
145 llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
146 permutedOuterDims.end());
147 for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
148 packInfo.outerDimsOnDomainPerm.push_back(
149 permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
150 : i);
151 LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
153 for (auto dim : packInfo.outerDimsOnDomainPerm)
154 llvm::dbgs() << dim << " ";
155 llvm::dbgs() << "\n";
156 });
157 }
158
159 return packInfo;
160}
161
162static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
163 ArrayRef<AffineExpr> exprs) {
164 // Compute `outer_dims_perm`. See example:
165 // current exprs : (d0, d1, d2, d3) -> (d2, d3)
166 // perm : [0, 3, 1, 2]
167 // First map d2, d3 with their position in the array as:
168 // currentPositionTileLoops: dim | pos
169 // d2 | 0
170 // d3 | 1
171 // then scan `perm` in order and get the `outer_dims_perm`
172 // to be used, here it would be [1, 0].
173 assert(!perm.empty() && "expect perm not to be empty");
174 assert(!exprs.empty() && "expect exprs not to be empty");
175 if (exprs.size() == 1)
176 return {};
177 SmallVector<int64_t> outerDimsPerm;
178 DenseMap<int64_t, int64_t> currentPositionTileLoops;
179 for (auto [pos, expr] : llvm::enumerate(exprs)) {
180 // Here we rely on the assumption that the outer dims permutation
181 // when propagating currently requires that non-affine dim expressions
182 // are not permuted, thus allowing the identity assignment below.
183 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
184 currentPositionTileLoops[dimExpr.getPosition()] = pos;
185 else
186 currentPositionTileLoops[pos] = pos;
187 }
188 for (int64_t loopIdx : perm) {
189 if (currentPositionTileLoops.count(loopIdx))
190 outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
191 }
192 return outerDimsPerm;
193}
194
195/// Returns a tuple for packed operand and indexing_map with the assumptions:
196/// 1) The generic op is the producer of the pack op.
197/// 2) The generic op has only one result.
198/// If the operand is a scalar or packing dimensions are all irrelevant to the
199/// operand, the operand and the updated indexing map will be returned.
200/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
201///
202/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
203/// #map1 = affine_map<(d0, d1) -> (d0)>
204/// #map2 = affine_map<(d0, d1) -> (d1)>
205/// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
206/// iterator_types = ["parallel", "parallel"]}
207/// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
208/// outs(%init : tensor<?x?xf32>) {
209/// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
210/// %4 = arith.addf %arg3, %arg4 : f32
211/// linalg.yield %4 : f32
212/// } -> tensor<?x?xf32>
213/// %1 = linalg.pack %0
214/// inner_dims_pos = [0, 1]
215/// inner_tiles = [8, 2]
216/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
217///
/// Taking the first input operand as an example, it maps to iteration
/// dimension d0, whose inner tile size is 8. Thus, the operation below and the
/// indexing map `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
221///
222/// %pack = linalg.pack %arg0
223/// inner_dims_pos = [0]
224/// inner_tiles = [8]
225/// into %init : tensor<?xf32> -> tensor<?x8xf32>
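///
/// Similarly (an illustrative continuation of the same example), the second
/// input operand maps to d1, whose inner tile size is 2, so it would be packed
/// as below (the destination name %init1 is only for illustration) and paired
/// with `affine_map<(d0, d1, d2, d3) -> (d1, d3)>`:
///
///   %pack = linalg.pack %arg1
///     inner_dims_pos = [0]
///     inner_tiles = [2]
///     into %init1 : tensor<?xf32> -> tensor<?x2xf32>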
226static std::tuple<Value, AffineMap>
227getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
228 GenericOp genericOp, OpOperand *opOperand) {
229 int64_t numOrigLoops = genericOp.getNumLoops();
230 int64_t numInnerLoops = packInfo.getNumTiledLoops();
231 int64_t numLoops = numOrigLoops + numInnerLoops;
232 AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
233 llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
234 SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
235
236 // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
237 if (genericOp.isScalar(opOperand) || exprs.empty())
238 return std::make_tuple(opOperand->get(),
239 AffineMap::get(numLoops, 0, exprs, b.getContext()));
240
241 // Step 1. Construct the information of packing data dimensions; append inner
242 // dimensions to the indexing maps for the operand.
243 for (auto [index, expr] : llvm::enumerate(exprs)) {
244 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
245 int64_t dimPos = dimExpr.getPosition();
246 domainDimToOperandDim[dimPos] = index;
247 continue;
248 }
249 }
250 SmallVector<int64_t> innerDimsPos;
251 SmallVector<OpFoldResult> innerTileSizes;
252 for (auto dimPos : packInfo.tiledDimsPos) {
253 if (!domainDimToOperandDim.count(dimPos))
254 continue;
255 int64_t index = domainDimToOperandDim[dimPos];
256 innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
257 innerDimsPos.push_back(index);
258 exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
259 }
260
261 // Step 2. Handle outer dim permutations.
262 SmallVector<int64_t> outerDimsPerm;
263 if (!packInfo.outerDimsOnDomainPerm.empty()) {
264 outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
265
266 // Step 2.1: Fold transpose into the linalg.generic.
267 SmallVector<int64_t> inversedOuterPerm =
268 invertPermutationVector(packInfo.outerDimsOnDomainPerm);
269 for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
270 if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
271 int64_t dimPos = dimExpr.getPosition();
272 exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
273 continue;
274 }
275 assert(isa<AffineConstantExpr>(exprs[i]) &&
276 "Attempted to permute non-constant and non-affine dim expression");
277 }
278 // Step 2.2: Undo the transposition on `exprs` and propagate the
279 // transposition on the pack using outerDimsPerm.
280 if (!outerDimsPerm.empty()) {
281 SmallVector<AffineExpr> auxVec = exprs;
282 for (const auto &en : enumerate(outerDimsPerm))
283 auxVec[en.index()] = exprs[en.value()];
284 exprs = auxVec;
285 }
286 }
287 auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
288
  // The operand does not have dimensions that relate to the pack op.
290 if (innerDimsPos.empty() && outerDimsPerm.empty())
291 return std::make_tuple(opOperand->get(), indexingMap);
292
293 auto empty = linalg::PackOp::createDestinationTensor(
294 b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
295 auto packedOperand = b.create<linalg::PackOp>(
296 loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
297 /*padding=*/std::nullopt, outerDimsPerm);
298 return std::make_tuple(packedOperand, indexingMap);
299}
300
301/// This function is a helper subroutine to pack a genericOp and return it. It
302/// will create a new generic op with the packed operand and the packed output
303/// according to packInfo when we attempt to push down unpack or bubble up pack
304/// around it. Implicitly this will only work when a packInfo can be obtained.
/// This makes sure that we are only using this function on parallel permuted
306/// dimensions.
307static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
308 Value dest, AffineMap packedOutIndexingMap,
309 const PackInfo &packInfo,
310 bool isFoldableUnpackPack) {
311 Location loc = genericOp.getLoc();
312 SmallVector<Value> inputOperands;
313 SmallVector<Value> inputOperandsFromUnpackedSource;
314 SmallVector<AffineMap> indexingMaps;
315 auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
316 return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
317 packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
318 llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
319 };
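  // Pack every DPS input operand. When folding is allowed, also record the
  // unpacked source of any unpack producer so that unpack->pack round-trips
  // with matching tiles can be elided below.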
320 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
321 auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
322 rewriter, loc, packInfo, genericOp, inputOperand);
323 auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
324 auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
325 if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
326 inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
327 } else {
328 inputOperandsFromUnpackedSource.push_back(packedOperand);
329 }
330 inputOperands.push_back(packedOperand);
331 indexingMaps.push_back(packedIndexingMap);
332 }
333
  // If the unpack->pack sequences can be folded, use the sources of the unpack
  // ops in any unpack->pack chains on the generic op operands.
336 if (isFoldableUnpackPack) {
337 inputOperands = inputOperandsFromUnpackedSource;
338 if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) {
339 auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>();
340 if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) {
341 dest = destUnPack.getSource();
342 }
343 }
344 }
345
346 int64_t numInnerLoops = packInfo.getNumTiledLoops();
347 SmallVector<utils::IteratorType> iterTypes =
348 genericOp.getIteratorTypesArray();
349 iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
350
351 indexingMaps.push_back(packedOutIndexingMap);
352
353 auto newGenericOp = rewriter.create<linalg::GenericOp>(
354 loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
355 /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
356 rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
357 newGenericOp.getRegion().begin());
358 return newGenericOp;
359}
360
/// Bubbles up a linalg.pack op through a producer generic op. This swaps
/// pack(generic) to generic(pack). The new generic op works on the packed
363/// domain; pack ops are created for input and output operands. E.g.,
364///
365/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
366/// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
367/// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
368/// %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
369/// %3 = linalg.generic {indexing_maps = [#map0, #map0],
370/// iterator_types = ["parallel", "parallel"]}
371/// ins(%arg0 : tensor<?x?xf32>)
372/// outs(%2 : tensor<?x?xf32>) {
373/// ^bb0(%arg3: f32, %arg4: f32):
374/// %4 = arith.addf %arg3, %arg3 : f32
375/// linalg.yield %4 : f32
376/// } -> tensor<?x?xf32>
377/// %4 = linalg.pack %3
378/// inner_dims_pos = [0, 1]
379/// inner_tiles = [8, 2]
380/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
381///
382/// will be converted to
383///
384/// #map = affine_map<()[s0] -> (s0 ceildiv 8)>
385/// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
386/// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
387/// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
388/// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
389/// %0 = affine.apply #map()[%dim]
390/// %1 = affine.apply #map1()[%dim_0]
391/// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
392/// %pack = linalg.pack %arg0
393/// inner_dims_pos = [0, 1]
394/// inner_tiles = [8, 2]
395/// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
396/// %3 = linalg.generic {indexing_maps = [#map2, #map2],
397/// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
398/// ins(%pack : tensor<?x?x8x2xf32>)
399/// outs(%arg1 : tensor<?x?x8x2xf32>) {
400/// ^bb0(%in: f32, %out: f32):
401/// %4 = arith.addf %in, %in : f32
402/// linalg.yield %4 : f32
403/// } -> tensor<?x?x8x2xf32>
404static FailureOr<GenericOp>
405bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
406 const ControlPropagationFn &controlFn) {
407 auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
408 if (!genericOp)
409 return failure();
410
411 // User controlled propagation function.
412 if (!controlFn(&packOp.getSourceMutable()))
413 return failure();
414
  // TODO: Enable propagation in the presence of linalg.index and
  // tensor.extract, likely as a separate pattern as the pack information and
  // propagation decision need to be inferred from the region of the generic.
418 if (hasGatherSemantics(genericOp))
419 return failure();
420
421 // TODO: Relax the restriction. We are able to bubble up the pack op through
422 // multi-result generic op. It just needs more work.
423 if (genericOp.getNumResults() != 1)
424 return failure();
425
426 // Bail-out if the result of the generic has multiple uses, as bubbling up
427 // creates recomputation if the generic has multiple users.
428 // TODO: Enable the case where every use is an identical pack op as no
429 // recomputation is needed in that case.
430 if (!genericOp->getResult(0).hasOneUse())
431 return failure();
432
433 // TODO: Add an option for allowing padding values. It could introduce
434 // undefined behavior if we unconditionally propagate pack op through all
435 // the ops. E.g., if the padding value is zero and there are division ops in
436 // a generic op. Some values of padding area could be NaN (0/0).
437 if (packOp.getPaddingValue())
438 return failure();
439
440 OpOperand *opOperand = genericOp.getDpsInitOperand(0);
441 auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
442 if (failed(packInfo))
443 return failure();
444
445 // We want to move the pack not the generic.
446 OpBuilder::InsertionGuard guard(rewriter);
447 rewriter.setInsertionPoint(genericOp);
448
449 // We need to handle two cases:
450 // 1) The linalg.pack destination is a tensor.empty. If this is the case, we
451 // create a new tensor.empty to avoid breaking dominance, as we are moving the
452 // linalg.pack above the linalg.generic.
453 // 2) The destination is not a tensor.empty. In this case we can replace only
454 // if the destination of the linalg.pack dominates the linalg.generic.
455 Value packOpDest = packOp.getDest();
456 if (!packOpDest.hasOneUse())
457 return failure();
458 if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
459 packOpDest = rewriter.create<tensor::EmptyOp>(
460 genericOp->getLoc(), emptyOp.getMixedSizes(),
461 emptyOp.getType().getElementType());
462 } else {
463 DominanceInfo dom(genericOp);
464 if (!dom.properlyDominates(packOpDest, genericOp))
465 return failure();
466 }
467
468 // Rebuild the indexing map for the corresponding init operand.
469 auto [packedOutOperand, packedOutIndexingMap] =
470 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
471 genericOp, opOperand);
472
  // If the dps init operand of the generic is a tensor.empty, forward the pack
  // op destination.
475 Value dest = packedOutOperand;
476 if (auto initTensor = genericOp.getDpsInitOperand(0)
477 ->get()
478 .getDefiningOp<tensor::EmptyOp>()) {
479 dest = packOpDest;
480 }
  // pack(unpack) isn't naively foldable because the unpack op can be from an
  // arbitrary domain, so we need to keep both.
483 return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
484 *packInfo, /*isFoldableUnpackPack=*/false);
485}
486
487/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
488struct BubbleUpPackOpThroughGenericOpPattern
489 : public OpRewritePattern<linalg::PackOp> {
490public:
491 BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
492 ControlPropagationFn fun)
493 : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
494
495 LogicalResult matchAndRewrite(linalg::PackOp packOp,
496 PatternRewriter &rewriter) const override {
497 auto genericOp =
498 bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
499 if (failed(genericOp))
500 return failure();
501 rewriter.replaceOp(packOp, genericOp->getResults());
502 return success();
503 }
504
505private:
506 ControlPropagationFn controlFn;
507};
508
/// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
/// add zero-padding entries to `high` and `low`, one for each of the point
/// loops introduced by the pack.
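///
/// A minimal sketch of the rewrite (illustrative IR; it assumes a constant
/// zero padding value and that only an untiled dimension is padded):
///
///   %padded = tensor.pad %src low[0, 1] high[0, 1] { tensor.yield %cst }
///       : tensor<?x?xf32> to tensor<?x?xf32>
///   %pack = linalg.pack %padded inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<?x?xf32> -> tensor<?x?x8xf32>
///
/// becomes
///
///   %pack = linalg.pack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %empty : tensor<?x?xf32> -> tensor<?x?x8xf32>
///   %padded = tensor.pad %pack low[0, 1, 0] high[0, 1, 0] { tensor.yield %cst }
///       : tensor<?x?x8xf32> to tensor<?x?x8xf32>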
512class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
513public:
514 BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
515 : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
516
517 LogicalResult matchAndRewrite(linalg::PackOp packOp,
518 PatternRewriter &rewriter) const override {
519 auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
520 if (!padOp)
521 return failure();
522
523 // User controlled propagation function.
524 if (!controlFn(&packOp.getSourceMutable()))
525 return failure();
526
527 // TODO: Enable padding when the padding values are the same.
528 if (packOp.getPaddingValue())
529 return failure();
530
531 // Fail for non-constant padding values. The body of the pad could
532 // depend on the padding indices and/or properties of the padded
533 // tensor so for now we fail.
534 // TODO: Support non-constant padding values.
535 Value paddingVal = padOp.getConstantPaddingValue();
536 if (!paddingVal)
537 return failure();
538
539 if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
540 return failure();
541
542 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
543
    // Bail out if one of the padded dimensions is a tiled one.
545 llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
546 llvm::SmallBitVector innerDims(paddedDims.size());
547 for (int64_t dim : innerDimsPos)
548 innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
550 return failure();
551
552 Location loc = padOp->getLoc();
553 OpBuilder::InsertionGuard guard(rewriter);
554 rewriter.setInsertionPoint(padOp);
555
556 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
557 SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
558 auto empty = linalg::PackOp::createDestinationTensor(
559 rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
560 outerDimsPerm);
561 auto sourcePack = rewriter.create<linalg::PackOp>(
562 loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
563 /*padding=*/std::nullopt, outerDimsPerm);
564
    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
566 SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
567 SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
568 if (!outerDimsPerm.empty()) {
569 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
570 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
571 }
572 // The tiled dimensions were verified to be unpadded above, so here we
573 // just append 0 for the inner tile dimensions.
574 size_t pointLoopsSize = innerDimsPos.size();
575 lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
576 highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
577
578 auto newPadOp = rewriter.create<tensor::PadOp>(
579 loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
580 padOp.getNofold());
581
582 // If the pad has more than one user, create an unpack on the new pad to
583 // replace the other uses.
584 if (!padOp->hasOneUse()) {
585 auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
586 rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
587 Value unpackedPad = rewriter.create<linalg::UnPackOp>(
588 loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
589 rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
590 }
591
592 // Replace the pack with the new pad.
593 rewriter.replaceOp(packOp, newPadOp.getResult());
594
595 return success();
596 }
597
598private:
599 ControlPropagationFn controlFn;
600};
601
602/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
603///
/// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
/// targetShape [16, 16, 32, 1], it returns [1, 2]: for pos 0, the inner-most
/// projected dim in [0, 1] is 1; for pos 2, the inner-most non-unit projected
/// dim in [2, 3] is 2 (dim 3 has unit size).
608///
609/// If all candidates in a reassociation are unit dims, it chooses the
610/// inner-most dim pos.
611static SmallVector<int64_t>
612projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
613 ArrayRef<ReassociationIndices> reassocIndices,
614 ArrayRef<int64_t> targetShape) {
615 SmallVector<int64_t> projectedDimsPos;
616 for (auto pos : dimsPos) {
617 // In the case all dims are unit, this will return the inner-most one.
618 int64_t projectedPos = reassocIndices[pos].back();
619 for (auto i : llvm::reverse(reassocIndices[pos])) {
620 int64_t dim = targetShape[i];
621 if (dim > 1 || ShapedType::isDynamic(dim)) {
622 projectedPos = i;
623 break;
624 }
625 }
626 projectedDimsPos.push_back(projectedPos);
627 }
628 return projectedDimsPos;
629}
630
631/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
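/// E.g., dimsPos [1, 2] with shape [?, 16, 4] and tileSizes [8, 1] returns
/// true; a dynamic or non-divisible dim at any queried position returns false.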
632static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
633 ArrayRef<int64_t> shape,
634 ArrayRef<int64_t> tileSizes) {
635 for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
636 int64_t dim = shape[pos];
637 if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
638 return false;
639 }
640 return true;
641}
642
/// Permute the reassociation indices and reindex them in sequence order.
644/// Returns the next dim pos in the sequence.
645///
646/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
647/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
648/// [[0], [1, 2]].
649static int64_t applyPermutationAndReindexReassoc(
650 SmallVector<ReassociationIndices> &reassocIndices,
651 ArrayRef<int64_t> permutation) {
652 if (!permutation.empty())
653 applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
654 int64_t nextPos = 0;
655 for (ReassociationIndices &indices : reassocIndices) {
656 for (auto &index : indices) {
657 index = nextPos;
658 nextPos += 1;
659 }
660 }
661 return nextPos;
662}
663
664/// Bubble up pack op through collapse shape op when the packed dims can be
665/// projected to the dims before collapsing. This is possible when the inner
666/// tile sizes can divide the projected dims.
667///
668/// For example:
669///
670/// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
671/// : tensor<?x16x4xf32> into tensor<?x4xf32>
672/// %pack = linalg.pack %collapsed outer_dims_perm = [0, 1]
673/// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
674/// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
675///
676/// can be transformed into:
677///
/// %pack = linalg.pack %in outer_dims_perm = [0, 1, 2]
679/// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
680/// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
681/// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
///     : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
683static LogicalResult
684bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
685 linalg::PackOp packOp,
686 PatternRewriter &rewriter) {
687 SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
688 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
689 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
690
691 ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
692 SmallVector<ReassociationIndices> reassocIndices =
693 collapseOp.getReassociationIndices();
  // Project inner tile pos to the dim pos before collapsing. For example, if
  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
  // to packing on dim y.
697 //
698 // Project to inner-most non-unit dims to increase the chance that they can be
699 // divided by the inner tile sizes. This is correct because for [..., x, 1],
700 // packing on dim 1 is equivalent to packing on dim x.
701 SmallVector<int64_t> projectedInnerDimsPos =
702 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
703
704 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
705 innerTileSizes)) {
706 return failure();
707 }
708 // Expand the outer dims permutation with the associated source dims for the
709 // new permutation after bubbling. This is because moving a collapsed dim is
710 // equivalent to moving the associated source dims together.
711 SmallVector<int64_t> newOuterDimsPerm;
712 for (auto outerPos : outerDimsPerm)
713 llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);
714
715 auto emptyOp = linalg::PackOp::createDestinationTensor(
716 rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
717 projectedInnerDimsPos, newOuterDimsPerm);
718 auto newPackOp = rewriter.create<linalg::PackOp>(
719 packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
720 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
721
722 SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
723 // First apply the permutation on the reassociations of the outer dims.
724 // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
725 // -> [[0], [1, 2]]
726 int64_t nextPos =
727 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
728 // Then add direct mapping for the inner tile dims.
729 for (size_t i = 0; i < innerDimsPos.size(); ++i) {
730 newReassocIndices.push_back({nextPos});
731 nextPos += 1;
732 }
733
734 auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
735 collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
736 rewriter.replaceOp(packOp, newCollapseOp);
737
738 return success();
739}
740
741/// Project dimsPos to their collapsed positions in the reassocIndices.
742///
743/// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
744/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
745/// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
746/// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
747static SmallVector<int64_t>
748projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
749 ArrayRef<ReassociationIndices> reassocIndices) {
750 SmallVector<int64_t> projectedPos;
751
752 // Map each dimension to the position of corresponding reassociation index.
753 for (auto pos : dimsPos) {
754 for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
755 // If the dimension is present in the current indices group, the group
756 // position within the reassociation map is the desired projected
757 // dimension position.
758 if (llvm::is_contained(indices, pos)) {
759 projectedPos.push_back(idx);
760 break;
761 }
762 }
763 }
764 assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
765
766 return projectedPos;
767}
768
769/// Bubble up pack op through expand shape op.
770///
771/// For example:
772///
773/// %expand = tensor.expand_shape %in [[0], [1, 2]]
774/// : tensor<?x64xf32> into tensor<?x4x16xf32>
/// %pack = linalg.pack %expand outer_dims_perm = [0, 1, 2]
///     inner_dims_pos = [2] inner_tiles = [8] into %empty
///     : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
778///
779/// can be transformed into:
780///
/// %pack = linalg.pack %in inner_dims_pos = [1] inner_tiles = [8]
///     into %empty : tensor<?x64xf32> -> tensor<?x8x8xf32>
784/// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
785/// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
786static LogicalResult
787bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
788 linalg::PackOp packOp,
789 PatternRewriter &rewriter) {
790 // Outer dimensions permutation is not supported currently.
791 // TODO: Handle outer_dims_perm variants.
792 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
794 return rewriter.notifyMatchFailure(packOp,
795 "non-identity outer dims perm NYI");
796 }
797
798 // Validate dimensions' relations between shape expansion and packing.
799 SmallVector<ReassociationIndices, 4> reassoc =
800 expandOp.getReassociationIndices();
801 ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
802 llvm::SetVector<int64_t> packDimsPos(llvm::from_range, packInnerDims);
803
804 for (auto [idx, indices] : llvm::enumerate(reassoc)) {
805 // For each expand_shape reassociation, figure out which dimensions get
806 // packed if any.
807 llvm::SetVector<int64_t> expandDimPos(llvm::from_range, indices);
808 llvm::SetVector<int64_t> packedDims =
809 llvm::set_intersection(packDimsPos, expandDimPos);
810
    // The expanded dimension is not packed, so it does not affect moving the
    // pack before the shape expansion - simply continue.
813 if (packedDims.empty())
814 continue;
    // Shape expansion cannot be propagated when multiple expanded dimensions
    // are packed - in this case operation reordering would affect the final
    // element positions and/or shapes can no longer be projected.
818 if (packedDims.size() != 1)
819 return rewriter.notifyMatchFailure(
820 packOp, "only one of the expanded dimensions can be packed");
    // Only the inner-most expanded dimension can be packed. Otherwise, the
    // element order would be affected after operation reordering.
823 if (packedDims.front() != indices.back())
824 return rewriter.notifyMatchFailure(
825 packOp, "can only pack the inner-most expanded dimension");
826 }
827
828 // Project pack.inner_dims_pos to positions before shape expansion.
829 SmallVector<int64_t> projectedInnerDimsPos =
830 projectDimsPosIntoReassocPos(packInnerDims, reassoc);
831
832 // Project the shape expansion to new packed shape.
  // The pack.outer_dims_perm is restricted to identity, so the permutation can
  // be omitted for simplicity.
835 // TODO: Account for outer dimensions permutation.
836 //
837 // If reassociation is not possible, then reordering cannot happen.
838 // This can be caused by pack padding affecting previously expanded
839 // dimensions or packing extending dimensions.
840 RankedTensorType newPackType = linalg::PackOp::inferPackedType(
841 expandOp.getSrcType(), packOp.getStaticInnerTiles(),
842 projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
843 auto reassocExpand =
844 getReassociationIndicesForReshape(newPackType, packOp.getDestType());
845 if (!reassocExpand)
846 return rewriter.notifyMatchFailure(
847 packOp, "could not reassociate dims after bubbling up");
848
849 Value destTensor = linalg::PackOp::createDestinationTensor(
850 rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
851 projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
852 Value packedVal = rewriter.create<linalg::PackOp>(
853 packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
854 packOp.getMixedTiles(), packOp.getPaddingValue(),
855 /*outerDimsPerm=*/SmallVector<int64_t>{});
856
857 Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
858 packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
859 rewriter.replaceOp(packOp, newExpandOp);
860
861 return success();
862}
863
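/// Wrapper pattern that dispatches to the collapse_shape/expand_shape bubbling
/// helpers above. It only fires when the pack op is the sole user of the
/// reshape result and all inner tile sizes are static.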
864class BubbleUpPackOpThroughReshapeOp final
865 : public OpRewritePattern<linalg::PackOp> {
866public:
867 BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
868 : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
869
870 LogicalResult matchAndRewrite(linalg::PackOp packOp,
871 PatternRewriter &rewriter) const override {
872 Operation *srcOp = packOp.getSource().getDefiningOp();
    // Currently, this is supported only when the pack op is the only user of
    // the reshape.
    if (!srcOp || srcOp->getNumResults() != 1 ||
        !srcOp->getResult(0).hasOneUse()) {
876 return failure();
877 }
878 // Currently only support static inner tile sizes.
879 if (llvm::any_of(packOp.getStaticTiles(), ShapedType::isDynamic))
880 return failure();
881
882 // User controlled propagation function.
883 if (!controlFn(&packOp.getSourceMutable()))
884 return failure();
885
886 return TypeSwitch<Operation *, LogicalResult>(srcOp)
887 .Case([&](tensor::CollapseShapeOp op) {
888 return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
889 })
890 .Case([&](tensor::ExpandShapeOp op) {
891 return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
892 })
893 .Default([](Operation *) { return failure(); });
894 }
895
896private:
897 ControlPropagationFn controlFn;
898};
899
900/// Push down unpack op through expand shape op when the packed dims can be
901/// projected to the dims after expanding. This is possible when the inner tile
902/// sizes can divide the projected dims.
903///
904/// For example:
905///
906/// %unpack = linalg.unpack %in outer_dims_perm = [0, 1]
907/// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
908/// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
909/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
910/// : tensor<?x256xf32> into tensor<?x256x256xf32>
911///
912/// can be transformed into:
913///
/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
915/// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
916/// %unpack = linalg.unpack %expanded outer_dims_perm = [0, 1, 2]
917/// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
918/// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
919static LogicalResult pushDownUnPackOpThroughExpandShape(
920 linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
921 PatternRewriter &rewriter, ControlPropagationFn controlFn) {
922 // User controlled propagation function.
923 if (!controlFn(&expandOp.getSrcMutable()))
924 return failure();
925
926 SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
927 ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
928 ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
929
930 auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
931 if (!expandTy)
932 return failure();
933 ArrayRef<int64_t> dstShape = expandTy.getShape();
934 SmallVector<ReassociationIndices> reassocIndices =
935 expandOp.getReassociationIndices();
  // Project inner tile pos to the dim pos after expanding. For example, if
  // dim z is expanded into dims [x, y], unpacking on dim z can be projected to
  // unpacking on dim y.
939 //
940 // Project to inner-most non-unit dims to increase the chance that they can be
941 // divided by the inner tile sizes. This is correct because for [..., x, 1],
942 // unpacking on dim 1 is equivalent to unpacking on dim x.
943 SmallVector<int64_t> projectedInnerDimsPos =
944 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
945
946 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
947 innerTileSizes)) {
948 return failure();
949 }
950 // Expand the outer dims permutation with the associated expanded dims for the
951 // new permutation after pushing. This is because moving a source dim is
952 // equivalent to moving the associated expanded dims together.
953 SmallVector<int64_t> newOuterDimsPerm;
954 for (auto outerPos : outerDimsPerm)
955 llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);
956
957 SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
958 // First apply the permutation on the reassociations of the outer dims.
959 // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
960 // -> [[0], [1, 2]]
961 int64_t nextPos =
962 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
963 // Then add direct mapping for the inner tile dims.
964 for (size_t i = 0; i < innerDimsPos.size(); ++i) {
965 newReassocIndices.push_back({nextPos});
966 nextPos += 1;
967 }
968
969 RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
970 expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
971 auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
972 expandOp.getLoc(), newExpandType, unPackOp.getSource(),
973 newReassocIndices);
974
975 auto emptyOp = linalg::UnPackOp::createDestinationTensor(
976 rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
977 projectedInnerDimsPos, newOuterDimsPerm);
978 auto newUnPackOp = rewriter.create<linalg::UnPackOp>(
979 unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
980 projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
981 rewriter.replaceOp(expandOp, newUnPackOp);
982
983 return success();
984}
985
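/// Wrapper pattern that pushes a linalg.unpack op with static inner tiles and
/// a single user down through a consumer tensor.expand_shape.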
986class PushDownUnPackOpThroughReshapeOp final
987 : public OpRewritePattern<linalg::UnPackOp> {
988public:
989 PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
990 ControlPropagationFn fun)
991 : OpRewritePattern<linalg::UnPackOp>(context), controlFn(std::move(fun)) {
992 }
993
994 LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
995 PatternRewriter &rewriter) const override {
996 Value result = unPackOp.getResult();
    // Currently, only an unpack op with a single user is supported.
998 if (!result.hasOneUse()) {
999 return failure();
1000 }
1001 // Currently only support static inner tile sizes.
1002 if (llvm::any_of(unPackOp.getStaticTiles(), ShapedType::isDynamic))
1003 return failure();
1004
1005 Operation *consumerOp = *result.user_begin();
1006 return TypeSwitch<Operation *, LogicalResult>(consumerOp)
1007 .Case([&](tensor::ExpandShapeOp op) {
1008 return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
1009 controlFn);
1010 })
1011 .Default([](Operation *) { return failure(); });
1012 }
1013
1014private:
1015 ControlPropagationFn controlFn;
1016};
1017
1018// TODO: Relax this restriction. We should unpack a generic op also
1019// in the presence of multiple unpack ops as producers.
1020/// Return the unpacked operand, if present, for the current generic op.
1021static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
1022 OpOperand *unPackedOperand = nullptr;
1023 for (OpOperand &operand : genericOp->getOpOperands()) {
1024 auto unPackOp = operand.get().getDefiningOp<linalg::UnPackOp>();
1025 if (!unPackOp)
1026 continue;
1027 if (unPackedOperand)
1028 return failure();
1029 unPackedOperand = &operand;
1030 }
1031 if (!unPackedOperand)
1032 return failure();
1033 return unPackedOperand;
1034}
1035
1036/// Push down a linalg.unpack op through a generic op.
1037/// The new generic op works on packed domain; pack ops are created for input
1038/// and output operands. A linalg.unpack op is inserted right after the packed
1039/// generic. E.g.
1040///
1041/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
1042///
1043/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
1044///
1045/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1046/// %1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
1047/// inner_dims_pos = [3] inner_tiles = [32] into %0
1048/// %2 = linalg.generic {indexing_maps = [#map],
1049/// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
1050/// outs(%1 : tensor<12x56x56x64xf32>) {
1051/// ^bb0(%out : f32):
1052/// linalg.yield %out : f32
1053/// } -> tensor<12x56x56x64xf32>
1054///
1055/// will be converted to
1056///
1057/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
1058///
1059/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1060/// %1 = linalg.generic {indexing_maps = [#map],
1061/// iterator_types = ["parallel", "parallel", "parallel",
1062/// "parallel", "parallel"]}
1063/// outs(%arg0 : tensor<12x2x56x56x32xf32>) {
1064/// ^bb0(%out : f32):
1065/// linalg.yield %out : f32
1066/// } -> tensor<12x2x56x56x32xf32>
1067/// %2 = linalg.unpack %1 outer_dims_perm = [0, 3, 1, 2]
1068/// inner_dims_pos = [3] inner_tiles = [32] into %0
1069///
1070static FailureOr<std::tuple<GenericOp, Value>>
1071pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
1072 ControlPropagationFn controlFn) {
1073 if (genericOp.getNumResults() != 1)
1074 return failure();
1075
1076 if (hasGatherSemantics(genericOp))
1077 return failure();
1078
1079 // Collect the unPacked operand, if present.
1080 auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
1081 if (failed(maybeUnPackedOperand))
1082 return failure();
1083 OpOperand *unPackedOperand = *(maybeUnPackedOperand);
1084
1085 // Extract packing information.
1086 linalg::UnPackOp producerUnPackOp =
1087 unPackedOperand->get().getDefiningOp<linalg::UnPackOp>();
1088 assert(producerUnPackOp && "expect a valid UnPackOp");
1089
1090 if (!controlFn(unPackedOperand))
1091 return failure();
1092
1093 auto packInfo =
1094 getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
1095 if (failed(packInfo))
1096 return failure();
1097
1098 // Rebuild the indexing map for the corresponding init operand.
1099 auto [packedOutOperand, packedOutIndexingMap] =
1100 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
1101 genericOp, genericOp.getDpsInitOperand(0));
1102 auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
1103
1104 // If the dps init operand of the generic is a tensor.empty, do not pack it
1105 // and forward the new tensor.empty as a destination.
1106 Value dest = packedOutOperand;
1107 if (auto initTensor = genericOp.getDpsInitOperand(0)
1108 ->get()
1109 .getDefiningOp<tensor::EmptyOp>()) {
1110 if (destPack)
1111 dest = destPack.getDest();
1112 }
1113
1114 // Pack the genericOp.
  // pack(unpack) is foldable in this case. This is because, in pushing down
  // the unpack, by default we will populate an additional pack op after the
  // unpack. This guarantees that they are foldable.
1118 GenericOp newGenericOp =
1119 packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
1120 /*isFoldableUnpackPack=*/true);
1121 Value newResult =
1122 newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
1123
1124 // If the output is unaffected, no need to unpack.
1125 if (!destPack)
1126 return std::make_tuple(newGenericOp, newResult);
1127
1128 auto mixedTiles = destPack.getMixedTiles();
1129 auto innerDimsPos = destPack.getInnerDimsPos();
1130 auto outerDimsPerm = destPack.getOuterDimsPerm();
1131
1132 // Insert an unPackOp right after the packed generic.
1133 Value unPackOpRes =
1134 rewriter
1135 .create<linalg::UnPackOp>(genericOp.getLoc(), newResult,
1136 destPack.getSource(), innerDimsPos,
1137 mixedTiles, outerDimsPerm)
1138 .getResult();
1139
1140 return std::make_tuple(newGenericOp, unPackOpRes);
1141}
1142
1143// Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
1144struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
1145public:
1146 PushDownUnPackOpThroughGenericOp(MLIRContext *context,
1147 ControlPropagationFn fun)
1148 : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1149
1150 LogicalResult matchAndRewrite(GenericOp genericOp,
1151 PatternRewriter &rewriter) const override {
1152 auto genericAndRepl =
1153 pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
1154 if (failed(genericAndRepl))
1155 return failure();
1156 rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
1157 return success();
1158 }
1159
1160private:
1161 ControlPropagationFn controlFn;
1162};
1163
/// Propagate a linalg.unpack operation through a tensor.pad. The idea is to
/// add zero-padding entries to `high` and `low`, one for each of the point
/// loops of the packed layout.
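///
/// A minimal sketch of the rewrite (illustrative IR; it assumes a constant
/// padding value, padding on untiled dimensions only, and the destination
/// names are only for illustration):
///
///   %unpack = linalg.unpack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %empty : tensor<?x4x8xf32> -> tensor<?x4xf32>
///   %padded = tensor.pad %unpack low[0, 1] high[0, 1] { tensor.yield %cst }
///       : tensor<?x4xf32> to tensor<?x6xf32>
///
/// becomes
///
///   %padded = tensor.pad %src low[0, 1, 0] high[0, 1, 0] { tensor.yield %cst }
///       : tensor<?x4x8xf32> to tensor<?x6x8xf32>
///   %unpack = linalg.unpack %padded inner_dims_pos = [0] inner_tiles = [8]
///       into %empty2 : tensor<?x6x8xf32> -> tensor<?x6xf32>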
1167struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
1168 PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
1169 : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}
1170
1171 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1172 PatternRewriter &rewriter) const override {
1173 linalg::UnPackOp unpackOp =
1174 padOp.getSource().getDefiningOp<linalg::UnPackOp>();
1175 if (!unpackOp)
1176 return failure();
1177
1178 if (!controlFn(&padOp.getSourceMutable()))
1179 return failure();
1180
1181 Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
1183 llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
1184 ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1185 llvm::SmallBitVector innerDims(paddedDims.size());
1186 for (int64_t dim : innerDimsPos)
1187 innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
1189 return failure();
1190
1191 Value paddingVal = padOp.getConstantPaddingValue();
1192 if (!paddingVal)
1193 return failure();
1194
    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
1196 ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1197 SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
1198 SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
1199 if (!outerDimsPerm.empty()) {
1200 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
1201 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
1202 }
1203 // Add zero padding for the point loops.
1204 size_t pointLoopsSize = innerDimsPos.size();
1205 lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1206 highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1207
1208 auto newPadOp = rewriter.create<tensor::PadOp>(
1209 loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
1210 paddingVal, padOp.getNofold());
1211
1212 // Inject the linalg.unpack right after the packed padOp.
1213 Value outputUnPack = rewriter.create<tensor::EmptyOp>(
1214 loc, padOp.getResultType().getShape(),
1215 padOp.getResultType().getElementType());
1216
1217 Value replacement = rewriter.create<linalg::UnPackOp>(
1218 loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
1219 unpackOp.getMixedTiles(), outerDimsPerm);
1220 rewriter.replaceOp(padOp, replacement);
1221 return success();
1222 }
1223
1224private:
1225 ControlPropagationFn controlFn;
1226};
1227
1228} // namespace
1229
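// A minimal usage sketch (illustrative; the exact greedy-driver entry point
// and the control function are up to the caller, and `ctx`/`funcOp` are
// assumed to be in scope):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDataLayoutPropagationPatterns(
//       patterns, [](OpOperand *) { return true; });
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));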
1230void mlir::linalg::populateDataLayoutPropagationPatterns(
1231 RewritePatternSet &patterns,
1232 const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns
      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
          patterns.getContext(), controlPackUnPackPropagation);
1238}
1239
