1//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements logic and helpers to expose Linalg transforms as rewrite
10// patterns.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/Linalg/IR/Linalg.h"
19#include "mlir/Dialect/Linalg/Utils/Utils.h"
20#include "mlir/Dialect/SCF/Transforms/Transforms.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
23#include "mlir/Dialect/Tensor/Utils/Utils.h"
24#include "mlir/Dialect/Utils/IndexingUtils.h"
25#include "mlir/Dialect/Utils/StaticValueUtils.h"
26#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
27#include "mlir/Dialect/Vector/IR/VectorOps.h"
28#include "mlir/IR/AffineExpr.h"
29#include "mlir/Support/LLVM.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/Debug.h"
32#include "llvm/Support/InterleavedRange.h"
33#include "llvm/Support/raw_ostream.h"
34#include <utility>
35
// Tag used by the LLVM_DEBUG output emitted from this file.
#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

// Debug-output helpers: DBGS() starts a line prefixed with the debug type
// ("[linalg-transforms]: "), DBGSNL() emits a bare newline.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
43
44//===----------------------------------------------------------------------===//
45// Transformations exposed as functional-style API calls.
46//===----------------------------------------------------------------------===//
47
48//===----------------------------------------------------------------------===//
49// peelLoop transformation.
50//===----------------------------------------------------------------------===//
51
52/// Try to peel and canonicalize loop `op` and return the new result.
53/// Also applies affine_min/max bounds simplification on the fly where relevant.
54// TODO: Add support for scf.parallel and affine.for loops.
55SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
56 Operation *op) {
57 return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
58 .Case<scf::ForOp>(caseFn: [&](scf::ForOp forOp) {
59 scf::ForOp partialIteration;
60 if (succeeded(Result: scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
61 partialIteration)))
62 return partialIteration->getResults();
63 assert(!partialIteration && "expected that loop was not peeled");
64 return forOp->getResults();
65 })
66 .Default(defaultFn: [&](Operation *op) { return op->getResults(); });
67}
68
69/// Peel 'loops' and applies affine_min/max bounds simplification on the fly
70/// where relevant.
71void mlir::linalg::peelLoops(RewriterBase &rewriter,
72 ArrayRef<scf::ForOp> loops) {
73 for (auto loopOp : loops)
74 peelLoop(rewriter, op: loopOp);
75}
76
77//===----------------------------------------------------------------------===//
78// pack transformation.
79//===----------------------------------------------------------------------===//
80
81#ifndef NDEBUG
82/// Return true if `map` has 0 or 1 result function of AffineDimExpr(dim).
83static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
84 bool found = false;
85 for (AffineExpr e : map.getResults()) {
86 if (!e.isFunctionOfDim(dim))
87 continue;
88 if (found)
89 return false;
90 found = true;
91 }
92 return true;
93}
94
95static std::string stringifyReassocIndices(ReassociationIndicesRef ri) {
96 return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/"");
97}
98#endif // NDEBUG
99
100/// Return the index of the first result of `map` that is a function of
101/// AffineDimExpr(dim), std::nullopt otherwise.
102static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
103 int64_t dim) {
104 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
105 AffineExpr expr = map.getResult(idx: i);
106 if (!expr.isFunctionOfDim(position: dim))
107 continue;
108 return i;
109 }
110 return std::nullopt;
111}
112
113/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
114/// `newDim` at `iteratorTypes.size()` by:
115/// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
116/// 2. Appending a `newDim` to the domain of every indexing map.
117/// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing
118/// by potentially adding a `newDim` result to `map`.
119/// The preserved invariant is that `iteratorTypes.size()` is always equal to
120/// `map.getNumDims()` for every map in `indexingMaps`.
121///
122/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
123/// Return a vector that records the optional packing for each operand.
124/// Return failure if the packed indexing cannot be represented with a LinalgOp.
125///
126/// Further details:
127/// ================
128/// The current implementation of packing (i.e. data tiling) consists of
129/// rewriting a linearized strip-mined form into a higher-dimensional access.
130/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
131/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
132/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
133///
134/// This rewrite into higher dimensional access is not possible for general
135/// AffineExpr in Linalg atm, it is restricted to an AffineDimExpr:
136/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
137/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
138/// The rewrite of the access would be a form not representable in Linalg:
139/// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
140/// Note however that as `J` and `ii` iterate, the accesses do not have a
141/// particular alignment, so packing does not achieve alignment in this case
142///
143/// In the future, we may want to consider a mixed-form that allows some
144/// alignment in the presence of multiple accesses:
145/// `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
146/// And would rewrite accesses as:
147/// `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
148static FailureOr<SmallVector<std::optional<int64_t>>>
149packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
150 SmallVectorImpl<utils::IteratorType> &iteratorTypes,
151 int64_t dim) {
152 int64_t newDim = iteratorTypes.size();
153 iteratorTypes.push_back(Elt: iteratorTypes[dim]);
154
155 SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
156 indexingMaps.size(), std::nullopt);
157 SmallVector<AffineMap> newMaps;
158 for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
159 ++operandIdx) {
160 AffineMap map = indexingMaps[operandIdx];
161
162 // Add the `newDim` to map whatever the case.
163 assert(map.getNumDims() == newDim && "num dims invariant violation");
164 map = map.shiftDims(shift: 1, offset: newDim);
165
166 // Get the at-most-1 index of the result that is a function of `dim`.
167 // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
168 // logically chunks dimension `dim` into `K * dim + newDim`, where the
169 // packing factor `K` is specified separately.
170 assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
171 "num results invariant violation");
172 auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
173 if (!maybeOperandDimensionToPack.has_value()) {
174 newMaps.push_back(Elt: map);
175 continue;
176 }
177
178 // We can only pack AffineDimExpr atm.
179 if (!isa<AffineDimExpr>(Val: map.getResult(idx: maybeOperandDimensionToPack.value())))
180 return failure();
181
182 // Add `newDim` to the results of the map.
183 map = map.insertResult(expr: Builder(map.getContext()).getAffineDimExpr(position: newDim),
184 pos: map.getNumResults());
185 newMaps.push_back(Elt: map);
186
187 // Record the that `operandIdx` is packed.
188 packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
189 }
190 indexingMaps = newMaps;
191
192 return packedDimPerIndexingMap;
193}
194
195namespace {
196
197/// Helper struct to encode packing along one dimension of a LinalgOp.
198struct PackedOperandsDim {
199 OpFoldResult packedSize;
200 SmallVector<std::optional<int64_t>> packedDimForEachOperand;
201};
202
203/// Helper struct to encode packing along all dimensions of a LinalgOp.
204struct PackedOperandsDimList {
205 void pushBack(PackedOperandsDim &&packedOperandsDims) {
206 spec.emplace_back(Args&: packedOperandsDims);
207 }
208 /// Return all the dims that have been packed for operand @ `operandPos`.
209 SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
210 /// Return all the pack sizes by which an operand @ `operandPos` is packed.
211 SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
212
213private:
214 SmallVector<PackedOperandsDim> spec;
215};
216
217} // namespace
218
219FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
220 linalg::PackOp packOp,
221 bool lowerPadLikeWithInsertSlice) {
222 // 1. Filter out NYI cases.
223 auto packedTensorType =
224 cast<RankedTensorType>(Val: packOp->getResultTypes().front());
225 if (llvm::any_of(Range: packOp.getStaticInnerTiles(), P: ShapedType::isDynamic)) {
226 return rewriter.notifyMatchFailure(
227 arg&: packOp,
228 msg: "non-static shape NYI, needs a more powerful tensor.expand_shape op");
229 }
230
231 Location loc = packOp->getLoc();
232 OpBuilder::InsertionGuard g(rewriter);
233 rewriter.setInsertionPoint(packOp);
234
235 // 2. Compute the permutation vector to shuffle packed shape into the shape
236 // before any outer or inner permutations have been applied.
237 PackingMetadata packingMetadata = computePackingMetadata(
238 packedRank: packedTensorType.getRank(), innerDimPos: packOp.getInnerDimsPos());
239 SmallVector<int64_t> packedToStripMinedShapePerm =
240 getPackInverseDestPerm(packOp);
241
242 // 3. Compute the stripMinedShape: this is the packed shape before any outer
243 // or inner permutations have been applied.
244 SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
245 applyPermutationToVector(inVec&: stripMinedShape, permutation: packedToStripMinedShapePerm);
246
247 // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
248 SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
249 rewriter.getIndexAttr(value: 0));
250 SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
251 rewriter.getIndexAttr(value: 0));
252 for (auto [pos, innerSize] :
253 llvm::zip_equal(t: packOp.getInnerDimsPos(), u: packOp.getMixedTiles())) {
254 int outerPos =
255 packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
256 OpFoldResult origSize =
257 tensor::getMixedSize(builder&: rewriter, loc, value: packOp.getSource(), dim: pos);
258 OpFoldResult outerSize =
259 tensor::getMixedSize(builder&: rewriter, loc, value: packOp.getDest(), dim: outerPos);
260 AffineExpr s0, d0, d1;
261 bindDims(ctx: rewriter.getContext(), exprs&: d0, exprs&: d1);
262 bindSymbols(ctx: rewriter.getContext(), exprs&: s0);
263 auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, result: d0 * s0 - d1);
264 highs[pos] = affine::makeComposedFoldedAffineApply(
265 b&: rewriter, loc, map, operands: {outerSize, origSize, innerSize});
266 }
267 RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
268 type: RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
269 reassociation: packingMetadata.reassociations);
270 Value paddingValue = packOp.getPaddingValue();
271 if (!paddingValue) {
272 paddingValue = rewriter.create<arith::ConstantOp>(
273 location: loc, args: rewriter.getZeroAttr(type: getElementTypeOrSelf(type: collapsed)));
274 }
275 auto padOp =
276 rewriter.create<tensor::PadOp>(location: loc, args&: collapsed, args: packOp.getSource(), args&: lows,
277 args&: highs, args&: paddingValue, /*nofold=*/args: false);
278
279 LLVM_DEBUG(
280 DBGSNL(); DBGSNL();
281 DBGS() << "insertPositions: "
282 << llvm::interleaved(packingMetadata.insertPositions);
283 DBGSNL(); DBGS() << "outerPositions: "
284 << llvm::interleaved(packingMetadata.outerPositions);
285 DBGSNL(); DBGS() << "packedShape: "
286 << llvm::interleaved(packedTensorType.getShape());
287 DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
288 << llvm::interleaved(packedToStripMinedShapePerm);
289 DBGSNL();
290 DBGS() << "reassociations: "
291 << llvm::interleaved(llvm::map_range(
292 packingMetadata.reassociations, stringifyReassocIndices));
293 DBGSNL();
294 DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
295 DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
296
297 if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
298 // Pack ops which operate as simple pads may not produce legal
299 // tensor.insert_slice operations when the packed type does not rank reduce
300 // to the padded type.
301 SliceVerificationResult rankReduces =
302 isRankReducedType(originalType: packedTensorType, candidateReducedType: padOp.getResultType());
303
304 if (rankReduces == SliceVerificationResult::Success) {
305 // This pack is just a plain pad.
306 // Just insert the pad in the higher ranked tensor.
307 // Offsets.
308 SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
309 rewriter.getIndexAttr(value: 0));
310 // Strides.
311 SmallVector<OpFoldResult> ones(packOp.getDestRank(),
312 rewriter.getIndexAttr(value: 1));
313 SmallVector<OpFoldResult> sizes =
314 tensor::getMixedSizes(builder&: rewriter, loc, value: packOp.getDest());
315
316 auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
317 location: loc, /*source=*/args&: padOp, /*dest=*/args: packOp.getDest(),
318 /*offsets=*/args&: zeros, args&: sizes, /*strides=*/args&: ones);
319
320 LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
321
322 rewriter.replaceOp(op: packOp, newValues: insertSliceOp->getResults());
323
324 return LowerPackResult{.padOp: padOp, /*reshapeOp=*/.expandShapeOp: nullptr,
325 /*transposeOp=*/nullptr};
326 }
327 }
328
329 // 5. Expand from the padded result to the stripMinedShape.
330 auto expandShapeResultType =
331 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
332 auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
333 location: loc, args&: expandShapeResultType, args: padOp.getResult(),
334 args&: packingMetadata.reassociations);
335
336 // 6. Transpose stripMinedShape to packedShape.
337 SmallVector<int64_t> transpPerm =
338 invertPermutationVector(permutation: packedToStripMinedShapePerm);
339 auto transposeOp = rewriter.create<linalg::TransposeOp>(
340 location: loc, args: reshapeOp.getResult(), args: packOp.getDest(), args&: transpPerm);
341
342 LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
343 DBGS() << "reshape op: " << reshapeOp; DBGSNL();
344 DBGS() << "transpPerm: " << llvm::interleaved(transpPerm);
345 DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
346
347 // 7. Replace packOp by transposeOp.
348 rewriter.replaceOp(op: packOp, newValues: transposeOp->getResults());
349
350 return LowerPackResult{.padOp: padOp, .expandShapeOp: reshapeOp, .transposeOp: transposeOp};
351}
352
/// Lower `unPackOp` and replace it.
/// - If `lowerUnpadLikeWithExtractSlice` is set and the op is unpad-like
///   (`isLikeUnPad()`), emit a single tensor.extract_slice of the source at
///   zero offsets and unit strides, with leading sizes of 1.
/// - Otherwise emit tensor.empty + linalg.transpose (packed -> strip-mined
///   shape), tensor.collapse_shape, tensor.extract_slice down to the
///   destination shape, and a linalg.copy into the original destination.
/// Returns the ops that were created; the ones not used by the chosen path
/// are null.
FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
                    bool lowerUnpadLikeWithExtractSlice) {
  Location loc = unPackOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unPackOp);

  RankedTensorType packedTensorType = unPackOp.getSourceType();
  int64_t packedRank = packedTensorType.getRank();

  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
    // This unpack is just a plain unpad.
    // Just extract the slice from the higher ranked tensor.
    ArrayRef<int64_t> destShape = destTensorType.getShape();
    // The inner dimensions stay the same as the destination tensor, but the
    // outer ones are additional 1s.
    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
    sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));

    // Rank-reducing slice: the result type is the destination type.
    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, destTensorType, unPackOp.getSource(),
        SmallVector<OpFoldResult>(packedRank, zero), sizes,
        SmallVector<OpFoldResult>(packedRank, one));

    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());

    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                               /*reshapeOp=*/nullptr, extractSliceOp};
  }

  // 1. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  SmallVector<int64_t> packedToStripMinedShapePerm =
      getUnPackInverseSrcPerm(unPackOp, packingMetadata);

  // 2. Compute the stripMinedShape: this is the packed shape without outer and
  // inner permutations.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 3. Transpose packedShape to stripMinedShape.
  RankedTensorType stripMinedTensorType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);

  // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
  // permutation.
  SmallVector<OpFoldResult, 4> dims =
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, packedToStripMinedShapePerm);
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, dims, stripMinedTensorType.getElementType());
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);

  LLVM_DEBUG(
      DBGSNL(); DBGSNL();
      DBGS() << "insertPositions: "
             << llvm::interleaved(packingMetadata.insertPositions);
      DBGSNL(); DBGS() << "packedShape: "
                       << llvm::interleaved(packedTensorType.getShape());
      DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
                       << llvm::interleaved(packedToStripMinedShapePerm);
      DBGSNL();
      DBGS() << "reassociations: "
             << llvm::interleaved(llvm::map_range(
                    packingMetadata.reassociations, stringifyReassocIndices));
      DBGSNL();
      DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););

  // 4. Collapse from the stripMinedShape to the padded result.
  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
      loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  // 5. ExtractSlice: drop the padding by slicing down to the dest sizes.
  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
      loc, destTensorType, reshapeOp->getResult(0),
      SmallVector<OpFoldResult>(destRank, zero),
      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
      SmallVector<OpFoldResult>(destRank, one));

  // 6. Inject a copy to preserve DPS.
  auto copyOp = rewriter.create<linalg::CopyOp>(
      loc, extractSliceOp->getResult(0), unPackOp.getDest());

  // 7. Replace unPackOp by copyOp.
  rewriter.replaceOp(unPackOp, copyOp->getResults());

  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
}
450
451SmallVector<int64_t>
452PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
453 SmallVector<int64_t> res;
454 for (auto &i : spec) {
455 if (!i.packedDimForEachOperand[operandPos].has_value())
456 continue;
457 res.push_back(Elt: i.packedDimForEachOperand[operandPos].value());
458 }
459 return res;
460}
461
462SmallVector<OpFoldResult>
463PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
464 SmallVector<OpFoldResult> res;
465 for (auto &i : spec) {
466 if (!i.packedDimForEachOperand[operandPos].has_value())
467 continue;
468 res.push_back(Elt: i.packedSize);
469 }
470 return res;
471}
472
473/// Implement packing of a single LinalgOp by performing packing by
474/// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator.
475/// Return the packed Linalg op on success, failure otherwise.
476FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
477 linalg::LinalgOp linalgOp,
478 ArrayRef<OpFoldResult> packedSizes) {
479 if (packedSizes.size() != linalgOp.getNumLoops()) {
480 return rewriter.notifyMatchFailure(arg&: linalgOp,
481 msg: "incorrect number of pack sizes");
482 }
483
484 Location loc = linalgOp->getLoc();
485 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
486 SmallVector<utils::IteratorType> iteratorTypes =
487 linalgOp.getIteratorTypesArray();
488 LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n"
489 << "maps: " << llvm::interleaved(indexingMaps) << "\n"
490 << "iterators: " << llvm::interleaved(iteratorTypes)
491 << "\n");
492
493 SmallVector<linalg::PackOp> packOps;
494 SmallVector<linalg::UnPackOp> unPackOps;
495 // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
496 PackedOperandsDimList listOfPackedOperandsDim;
497 for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
498 std::optional<int64_t> maybeConstant = getConstantIntValue(ofr: packedSizes[i]);
499 // Skip tile sizes explicitly set to 0.
500 if (maybeConstant.has_value() && maybeConstant.value() == 0)
501 continue;
502
503 PackedOperandsDim packedOperandsDims;
504 packedOperandsDims.packedSize = packedSizes[i];
505 FailureOr<SmallVector<std::optional<int64_t>>>
506 maybePackedDimForEachOperand =
507 packLinalgMetadataOnce(indexingMaps, iteratorTypes, dim: i);
508 if (failed(Result: maybePackedDimForEachOperand))
509 return failure();
510 packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
511 listOfPackedOperandsDim.pushBack(packedOperandsDims: std::move(packedOperandsDims));
512
513 LLVM_DEBUG(
514 DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
515 << "\n"
516 << "maps: " << llvm::interleaved(indexingMaps) << "\n"
517 << "iterators: " << llvm::interleaved(iteratorTypes) << "\n"
518 << "packedDimForEachOperand: "
519 << llvm::interleaved(packedOperandsDims.packedDimForEachOperand)
520 << "\n");
521 }
522
523 // Step 2. Propagate packing to all LinalgOp operands.
524 SmallVector<Value> inputsAndInits, results;
525 SmallVector<OpOperand *> initOperands =
526 llvm::to_vector(Range: llvm::make_pointer_range(Range: linalgOp.getDpsInitsMutable()));
527 SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
528 for (const auto &operandsList : {inputOperands, initOperands}) {
529 for (OpOperand *opOperand : operandsList) {
530 int64_t pos = opOperand->getOperandNumber();
531 Value operand = opOperand->get();
532 SmallVector<int64_t> innerPos =
533 listOfPackedOperandsDim.extractPackedDimsForOperand(operandPos: pos);
534 SmallVector<OpFoldResult> innerPackSizes =
535 listOfPackedOperandsDim.extractPackSizesForOperand(operandPos: pos);
536 LLVM_DEBUG(DBGS() << "operand: " << operand << "\n"
537 << "innerPos: " << llvm::interleaved(innerPos) << "\n"
538 << "innerPackSizes: "
539 << llvm::interleaved(innerPackSizes) << "\n");
540 if (innerPackSizes.empty()) {
541 inputsAndInits.push_back(Elt: operand);
542 continue;
543 }
544 Value dest = linalg::PackOp::createDestinationTensor(
545 b&: rewriter, loc, source: operand, innerTileSizes: innerPackSizes, innerDimsPos: innerPos,
546 /*outerDimsPerm=*/{});
547 ShapedType operandType = cast<ShapedType>(Val: operand.getType());
548 bool areConstantTiles =
549 llvm::all_of(Range&: innerPackSizes, P: [](OpFoldResult tile) {
550 return getConstantIntValue(ofr: tile).has_value();
551 });
552 if (areConstantTiles && operandType.hasStaticShape() &&
553 !linalg::PackOp::requirePaddingValue(
554 inputShape: operandType.getShape(), innerDimsPos: innerPos,
555 outputShape: cast<ShapedType>(Val: dest.getType()).getShape(), outerDimsPerm: {},
556 innerTiles: innerPackSizes)) {
557 packOps.push_back(Elt: rewriter.create<linalg::PackOp>(
558 location: loc, args&: operand, args&: dest, args&: innerPos, args&: innerPackSizes));
559 } else {
560 // TODO: value of the padding attribute should be determined by
561 // consumers.
562 auto zeroAttr =
563 rewriter.getZeroAttr(type: getElementTypeOrSelf(type: dest.getType()));
564 Value zero = rewriter.create<arith::ConstantOp>(location: loc, args&: zeroAttr);
565 packOps.push_back(Elt: rewriter.create<linalg::PackOp>(
566 location: loc, args&: operand, args&: dest, args&: innerPos, args&: innerPackSizes, args&: zero));
567 }
568 inputsAndInits.push_back(Elt: packOps.back());
569 }
570 }
571
572 // Step 3. Build the packed op, use the type of `inits` as result types.
573 ValueRange inputs =
574 ValueRange{inputsAndInits}.take_front(n: linalgOp.getNumDpsInputs());
575 ValueRange inits =
576 ValueRange{inputsAndInits}.take_back(n: linalgOp.getNumDpsInits());
577 auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
578 location: linalgOp.getLoc(), args: inits.getTypes(), args&: inputs, args&: inits, args&: indexingMaps,
579 args&: iteratorTypes);
580 packedLinalgOp.getRegion().takeBody(other&: linalgOp->getRegion(index: 0));
581
582 // Step 4. Propagate packing to all the op results.
583 for (OpResult result : packedLinalgOp->getResults()) {
584 int64_t resultNum = result.getResultNumber();
585 linalg::PackOp maybePackedInit =
586 inits[resultNum].getDefiningOp<linalg::PackOp>();
587 if (!maybePackedInit) {
588 results.push_back(Elt: result);
589 continue;
590 }
591 // Build the symmetrical UnPackOp to the existing PackOp.
592 unPackOps.push_back(Elt: rewriter.create<linalg::UnPackOp>(
593 location: packedLinalgOp->getLoc(), args&: result, args: maybePackedInit.getSource(),
594 args: maybePackedInit.getInnerDimsPos(), args: maybePackedInit.getMixedTiles()));
595 results.push_back(Elt: unPackOps.back());
596 }
597
598 // Step 5. Replace `linalgOp`.
599 rewriter.replaceOp(op: linalgOp, newValues: results);
600
601 // Return packedLinalgOp.
602 return PackResult{.packOps: packOps,
603 .packedLinalgOp: cast<linalg::LinalgOp>(Val: packedLinalgOp.getOperation()),
604 .unPackOps: unPackOps};
605}
606
607//===----------------------------------------------------------------------===//
608// packTranspose transformation.
609//===----------------------------------------------------------------------===//
610
611/// Return a copy of `tensorType` after permutation by `permutationVector`.
612// Note: Should be a new method in of MemRef/RankedTensor/VectorType::Builder
613// but this would introduce a dependence on Dialect in IR.
614// TODO: Restructure.
615static RankedTensorType permuteShape(RankedTensorType tensorType,
616 ArrayRef<int64_t> permutationVector) {
617 SmallVector<int64_t> shape(tensorType.getShape());
618 applyPermutationToVector(inVec&: shape, permutation: permutationVector);
619 return RankedTensorType::Builder(tensorType).setShape(shape);
620}
621
622/// Return a new GenericOp obtained by transposing opOperand by the permutation
623/// vector:
624/// - the corresponding indexing map is transposed by `permutation`
625/// - the corresponding operand value is replaced by `transposedValue`
626/// `linalgOp` is replaced by the return op in the process.
627/// Asserts that `transposedValue` is of the proper transposed ShapedType.
628static LinalgOp transposeOneLinalgOperandAndReplace(
629 RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
630 ArrayRef<int64_t> permutation, Value transposedValue) {
631 // Sanity check the operand.
632 assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
633
634 // Sanity check of the expected transposed tensor type.
635 auto tensorType = permuteShape(
636 tensorType: cast<RankedTensorType>(Val: opOperand.get().getType()), permutationVector: permutation);
637 (void)tensorType;
638 assert(tensorType == transposedValue.getType() &&
639 "expected tensor type mismatch");
640
641 // Compute the transposed indexing map.
642 // Sigh unsigned pollution.
643 SmallVector<unsigned> tmpTransposition = llvm::to_vector(
644 Range: llvm::map_range(C&: permutation, F: [](int64_t i) -> unsigned { return i; }));
645 AffineMap permutationMap =
646 AffineMap::getPermutationMap(permutation: tmpTransposition, context: rewriter.getContext());
647 AffineMap transposedMap =
648 permutationMap.compose(map: linalgOp.getMatchingIndexingMap(opOperand: &opOperand));
649
650 // Set the transposed indexing map in the proper position.
651 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
652 indexingMaps[linalgOp.getIndexingMapIndex(opOperand: &opOperand)] = transposedMap;
653 // Set the transposedValue in the proper operand position.
654 SmallVector<Value> operands = linalgOp->getOperands();
655 operands[opOperand.getOperandNumber()] = transposedValue;
656
657 ValueRange operandsRef(operands);
658 auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
659 /*location=*/linalgOp->getLoc(),
660 /*resultTensorTypes=*/
661 args: operandsRef.drop_front(n: linalgOp.getNumDpsInputs()).getTypes(),
662 /*inputs=*/args: operandsRef.take_front(n: linalgOp.getNumDpsInputs()),
663 /*outputs=*/args: operandsRef.drop_front(n: linalgOp.getNumDpsInputs()),
664 /*indexingMaps=*/args&: indexingMaps,
665 /*iteratorTypes=*/args: linalgOp.getIteratorTypesArray());
666 transposedGenericOp.getRegion().takeBody(other&: linalgOp->getRegion(index: 0));
667 rewriter.replaceOp(op: linalgOp, newValues: transposedGenericOp->getResults());
668
669 return cast<linalg::LinalgOp>(Val: transposedGenericOp.getOperation());
670}
671
672FailureOr<PackTransposeResult>
673linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
674 linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
675 ArrayRef<int64_t> outerPerm,
676 ArrayRef<int64_t> innerPerm) {
677 Location loc = linalgOp.getLoc();
678
679 // Step 1. Transpose packOp.
680 rewriter.setInsertionPoint(packOp);
681 linalg::PackOp transposedPackOp =
682 packOp.createTransposedClone(b&: rewriter, loc, innerPermutation: innerPerm, outerPermutation: outerPerm);
683
684 if (!packOp.getResult().hasOneUse())
685 return rewriter.notifyMatchFailure(arg&: linalgOp, msg: "expect single pack use");
686
687 OpOperand &packUse = *packOp->getUses().begin();
688 if (packUse.getOwner() != linalgOp) {
689 return rewriter.notifyMatchFailure(
690 arg&: linalgOp, msg: "not a single use by the LinalgOp target");
691 }
692 if (maybeUnPackOp &&
693 (!linalgOp.isDpsInit(opOperand: &packUse) ||
694 maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(opOperand: &packUse))) {
695 return rewriter.notifyMatchFailure(arg&: linalgOp,
696 msg: "not produced by the LinalgOp target");
697 }
698
699 // Step 2. Transpose linalgOp.
700 // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
701 // identity. Don't rely on it.
702 int64_t numLeadingDims = packOp.getSourceRank();
703 int64_t numTrailingDims = packOp.getInnerDimsPos().size();
704 // Step 2.a. Compute the permutation on the whole operand.
705 // Leading part just reuse the outerPerm.
706 SmallVector<int64_t> permutation(outerPerm);
707 if (permutation.empty())
708 llvm::append_range(C&: permutation, R: llvm::seq<int64_t>(Begin: 0, End: numLeadingDims));
709 // Trailing part needs to reindex positions by `numLeadingDims`.
710 if (innerPerm.empty()) {
711 llvm::append_range(
712 C&: permutation,
713 R: llvm::seq<int64_t>(Begin: numLeadingDims, End: numLeadingDims + numTrailingDims));
714 } else {
715 llvm::append_range(C&: permutation,
716 R: llvm::map_range(C&: innerPerm, F: [&](int64_t pos) {
717 return numLeadingDims + pos;
718 }));
719 }
720 if (!isPermutationVector(interchange: permutation))
721 return rewriter.notifyMatchFailure(arg&: linalgOp, msg: "invalid permutation");
722
723 // Step 2.b. Save the transposedPackUse operand number in case we need to
724 // get the tied OpResult after `linalgOp` has been replaced.
725 int64_t packUseOperandNumber = packUse.getOperandNumber();
726 // Step 2.c. Actually perform the transposition.
727 rewriter.setInsertionPoint(linalgOp);
728 linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
729 rewriter, linalgOp, opOperand&: packUse, permutation, transposedValue: transposedPackOp.getResult());
730
731 // Step 3. Maybe transpose unPackOp.
732 linalg::UnPackOp transposedUnPackOp;
733 if (maybeUnPackOp) {
734 OpOperand &opOperand =
735 transposedLinalgOp->getOpOperand(idx: packUseOperandNumber);
736 OpResult transposedResult = transposedLinalgOp.getTiedOpResult(opOperand: &opOperand);
737 rewriter.setInsertionPoint(maybeUnPackOp);
738 transposedUnPackOp = maybeUnPackOp.createTransposedClone(
739 b&: rewriter, loc, transposedSource: transposedResult, innerPermutation: innerPerm, outerPermutation: outerPerm);
740
741 rewriter.replaceOp(op: maybeUnPackOp, newValues: transposedUnPackOp->getResults());
742 }
743
744 // Step 4. Finally, replace packOp now that we don't need it anymore.
745 rewriter.replaceOp(op: packOp, newValues: transposedPackOp->getResults());
746
747 return PackTransposeResult{.transposedPackOp: transposedPackOp, .transposedLinalgOp: transposedLinalgOp,
748 .transposedUnPackOp: transposedUnPackOp};
749}
750
751//===----------------------------------------------------------------------===//
752// packMatmulGreedily transformation.
753//===----------------------------------------------------------------------===//
754
755/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
756/// and n are proper parallel dimensions and k is a proper reduction
757/// dimension. Packing occurs by rewriting the op as a linalg.generic and
758/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
759/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
760/// to reorder {m, n, k} into one of the 8 possible forms. The outer
761/// dimensions of the operands are not permuted at this time, this is left for
762/// future work.
763FailureOr<PackResult>
764linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
765 ArrayRef<OpFoldResult> mnkPackedSizes,
766 ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
767 ArrayRef<int64_t> mnkOrder) {
768 assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
769 assert((mnkPaddedSizesNextMultipleOf.empty() ||
770 mnkPaddedSizesNextMultipleOf.size() == 3) &&
771 "num of packing sizes next multiple should be empty or of size 3");
772 assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
773 assert(isPermutationVector(mnkOrder) && "expected a permutation");
774
775 int64_t numLoops = linalgOp.getNumLoops();
776 if (numLoops <= 2) {
777 LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
778 << numLoops << "\nin: " << linalgOp << "\n");
779 return rewriter.notifyMatchFailure(
780 arg&: linalgOp, msg: "need 3+ loops to find a matmul to pack");
781 }
782
783 // Locally adjust the desired iterator position of mnk and packing sizes.
784 int64_t numPackedDims = mnkPackedSizes.size();
785 SmallVector<int64_t> mmnnkkPos(numPackedDims);
786 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
787 mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
788 SmallVector<OpFoldResult> packedSizes(numPackedDims);
789 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
790 packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
791 SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
792 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
793 paddedSizesNextMultipleOf[mnkOrder[i]] =
794 mnkPaddedSizesNextMultipleOf.empty() ? 0
795 : mnkPaddedSizesNextMultipleOf[i];
796 }
797
798 // 1. Infer dims that are important for matmul.
799 FailureOr<ContractionDimensions> maybeDimensions =
800 inferContractionDims(linalgOp);
801 if (failed(Result: maybeDimensions)) {
802 LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
803 << "\n");
804 return rewriter.notifyMatchFailure(arg&: linalgOp,
805 msg: "couldn't infer matmul iterators");
806 }
807
808 // 2. Normalize linalgOp to an kmn-matmul-like with [red, par, par] most
809 // minor iterators. In cases with multiple options for m, n, k bias towards
810 // the most minor embedding.
811 // If we wanted a different normalization order, this is where it would have
812 // to plug a heuristic.
813 int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
814 kPos = maybeDimensions->k.back();
815 LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
816 DBGS() << "Start packing generic op greedily with (m@" << mPos
817 << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
818 << "\n";);
819
820 // 2.a. Rewrite as a generic.
821 auto genericOp = dyn_cast<GenericOp>(Val: linalgOp.getOperation());
822 if (!genericOp) {
823 FailureOr<GenericOp> generalizeResult =
824 generalizeNamedOp(rewriter, linalgOp);
825 assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
826 genericOp = *generalizeResult;
827 }
828
829 // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
830 // iterators. Note that this only normalized the iteration order and does
831 // not change the indexings of any operand.
832 SmallVector<int64_t> permutation =
833 computePermutationVector(permSize: numLoops, positions: {mPos, nPos, kPos}, desiredPositions: mmnnkkPos);
834 LLVM_DEBUG(DBGS() << "perm: " << llvm::interleaved(permutation) << "\n");
835 // Sign .. unsigned pollution.
836 SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
837 FailureOr<GenericOp> interchangeResult =
838 interchangeGenericOp(rewriter, genericOp, interchangeVector: unsignedPerm);
839 assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
840 genericOp = *interchangeResult;
841 LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
842
843 // At this point, the op iterators are normalized to {leading, k, m, n}.
844 // The layouts induced by packing will always be:
845 // - LHS{leading_lhs, kk, mm}
846 // - RHS{leading_rhs, kk, nn}
847 // - RES{leading_res, mm, nn}
848 // If we wanted to change the packed order, we would reorder (k, m, n) to
849 // something else above.
850 //
851 // Additional permutations of the outer dims of the operands (i.e.
852 // leading_lhs, leading_rhs and leading_res) could follow by computing the
853 // desired outerPerm for each operand.
854 // This is left for future work.
855
856 // TODO: this creates too much IR, go use reifyResultShapes.
857 SmallVector<Range, 4> loopRanges =
858 cast<LinalgOp>(Val: genericOp.getOperation())
859 .createLoopRanges(b&: rewriter, loc: genericOp.getLoc());
860
861 // Add leading zeros to match numLoops, we only pack the last 3 dimensions
862 // post interchange.
863 LLVM_DEBUG(DBGS() << "paddedSizesNextMultipleOf: "
864 << llvm::interleaved(paddedSizesNextMultipleOf) << "\n"
865 << "loopRanges: "
866 << llvm::interleaved(llvm::map_range(
867 loopRanges, [](Range r) { return r.size; }))
868 << "\n");
869 SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
870 rewriter.getIndexAttr(value: 0));
871 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
872 if (paddedSizesNextMultipleOf[i] == 0) {
873 adjustedPackedSizes.push_back(Elt: packedSizes[i]);
874 continue;
875 }
876 AffineExpr d0, s0;
877 bindDims(ctx: rewriter.getContext(), exprs&: d0);
878 bindSymbols(ctx: rewriter.getContext(), exprs&: s0);
879 adjustedPackedSizes.push_back(Elt: affine::makeComposedFoldedAffineApply(
880 b&: rewriter, loc: genericOp->getLoc(), expr: d0.ceilDiv(other: s0) * s0,
881 operands: {loopRanges[adjustedPackedSizes.size()].size,
882 rewriter.getIndexAttr(value: paddedSizesNextMultipleOf[i])}));
883 }
884 LLVM_DEBUG(DBGS() << "adjustedPackedSizes: "
885 << llvm::interleaved(adjustedPackedSizes) << "\n");
886
887 // TODO: If we wanted to give the genericOp a name after packing, after
888 // calling `pack` would be a good time. One would still need to check that
889 // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
890 // also allow degenerate matmul cases (i.e. matvec, dot).
891 return pack(rewriter, linalgOp: genericOp, packedSizes: adjustedPackedSizes);
892}
893
894//===----------------------------------------------------------------------===//
895// Transformations exposed as rewrite patterns.
896//===----------------------------------------------------------------------===//
897
898LinalgTilingOptions &
899mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
900 assert(!tileSizeComputationFunction && "tile sizes already set");
901 SmallVector<int64_t, 4> tileSizes(ts);
902 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
903 OpBuilder::InsertionGuard guard(b);
904 b.setInsertionPointToStart(
905 &op->getParentOfType<func::FuncOp>().getBody().front());
906 return llvm::to_vector<4>(Range: map_range(C: tileSizes, F: [&](int64_t s) {
907 Value v = b.create<arith::ConstantIndexOp>(location: op->getLoc(), args&: s);
908 return v;
909 }));
910 };
911 return *this;
912}
913
914LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
915 memref::CopyOp copyOp, PatternRewriter &rewriter) const {
916 return vectorizeCopy(builder&: rewriter, copyOp);
917}
918
919/// Filling `dest` using FillOp constant padding value if possible.
920/// Otherwise, generate a tensor::GenerateOp.
921Value DecomposePadOpPattern::createFillOrGenerateOp(
922 RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
923 const SmallVector<Value> &dynSizes) const {
924 auto padValue = padOp.getConstantPaddingValue();
925 if (padValue) {
926 // Move the padding value defined inside the PadOp block to outside.
927 if (padValue.getParentBlock() == &padOp.getRegion().front())
928 rewriter.moveOpBefore(op: padValue.getDefiningOp(), existingOp: padOp);
929 return rewriter.create<FillOp>(location: padOp.getLoc(), args&: padValue, args&: dest).result();
930 }
931
932 // Fill could not be optimized: Lower to tensor::GenerateOp with region.
933 auto generateOp = rewriter.create<tensor::GenerateOp>(
934 location: padOp.getLoc(), args: padOp.getResultType(), args: dynSizes);
935 // Copy region to new op.
936 IRMapping bvm;
937 padOp.getRegion().cloneInto(dest: &generateOp.getRegion(), mapper&: bvm);
938 return generateOp;
939}
940
941LogicalResult
942DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
943 PatternRewriter &rewriter) const {
944 // Given an OpFoldResult, return an index-typed value.
945 auto getIdxValue = [&](OpFoldResult ofr) {
946 if (auto val = llvm::dyn_cast_if_present<Value>(Val&: ofr))
947 return val;
948 return rewriter
949 .create<arith::ConstantIndexOp>(
950 location: padOp.getLoc(), args: cast<IntegerAttr>(Val: cast<Attribute>(Val&: ofr)).getInt())
951 .getResult();
952 };
953
954 auto resultType = padOp.getResultType();
955 // Compute size of EmptyOp. Any combination of static/dynamic is supported.
956 SmallVector<Value> dynSizes;
957 SmallVector<int64_t> staticSizes;
958 for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
959 if (resultType.isDynamicDim(idx: dim)) {
960 auto srcSize = getIdxValue(tensor::getMixedSize(builder&: rewriter, loc: padOp.getLoc(),
961 value: padOp.getSource(), dim));
962 // Add low and high padding value.
963 auto plusLow = rewriter.createOrFold<arith::AddIOp>(
964 location: padOp.getLoc(), args&: srcSize, args: getIdxValue(padOp.getMixedLowPad()[dim]));
965 auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
966 location: padOp.getLoc(), args&: plusLow, args: getIdxValue(padOp.getMixedHighPad()[dim]));
967 dynSizes.push_back(Elt: plusHigh);
968 }
969 staticSizes.push_back(Elt: resultType.getDimSize(idx: dim));
970 }
971
972 // Init tensor and fill it with padding.
973 Value emptyTensor = rewriter.create<tensor::EmptyOp>(
974 location: padOp.getLoc(), args&: staticSizes, args: resultType.getElementType(), args&: dynSizes);
975 Value fill = createFillOrGenerateOp(rewriter, padOp, dest: emptyTensor, dynSizes);
976
977 // Generate a InsertSliceOp for copying the PadOp source.
978 auto sourceType = padOp.getSourceType();
979 // Compute size of source of tensor::PadOp.
980 SmallVector<OpFoldResult> srcSizes =
981 tensor::getMixedSizes(builder&: rewriter, loc: padOp.getLoc(), value: padOp.getSource());
982 // Strides of InsertSliceOp are all 1.
983 SmallVector<OpFoldResult> strides(sourceType.getRank(),
984 rewriter.getIndexAttr(value: 1));
985 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
986 op: padOp, args: padOp.getSource(), args&: fill, args: padOp.getMixedLowPad(), args&: srcSizes,
987 args&: strides);
988
989 return success();
990}
991
992LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
993 tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
994 if (!sliceOp.hasUnitStride())
995 return failure();
996
997 auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
998 if (!padOp)
999 return failure();
1000
1001 bool zeroSliceGuard = true;
1002 if (controlFn) {
1003 if (std::optional<bool> control = controlFn(sliceOp))
1004 zeroSliceGuard = *control;
1005 else
1006 return failure();
1007 }
1008
1009 FailureOr<TilingResult> tilingResult =
1010 tensor::bubbleUpPadSlice(b&: rewriter, padOp, offsets: sliceOp.getMixedOffsets(),
1011 sizes: sliceOp.getMixedSizes(), generateZeroSliceGuard: zeroSliceGuard);
1012 if (failed(Result: tilingResult))
1013 return failure();
1014
1015 RankedTensorType sourceType = sliceOp.getSourceType();
1016 RankedTensorType resultType = sliceOp.getResultType();
1017
1018 // If the extract_slice is not rank-reduced, all shapes are static and the
1019 // data source is actually used. Rewrite into pad(extract_slice(x)).
1020 if (sourceType.getRank() == resultType.getRank()) {
1021 rewriter.replaceOp(op: sliceOp, newValues: tilingResult->tiledValues);
1022 return success();
1023 }
1024
1025 // Handle rank-reduced slice by creating another extract_slice op.
1026 Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
1027 b&: rewriter, loc: sliceOp.getLoc(), tensor: tilingResult->tiledValues[0], targetType: resultType);
1028
1029 rewriter.replaceOp(op: sliceOp, newValues: rankReduced);
1030 return success();
1031}
1032
1033/// If padding value is set, returns a tensor.pad Op for the source tensor,
1034/// with the output shape matching the output of `packOp`. Otherwise, returns
1035/// the source directly.
1036///
1037/// This method assumes that all outer dims for this pack Op are 1.
1038static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
1039 linalg::PackOp packOp) {
1040 Value input = packOp.getSource();
1041 if (!packOp.getPaddingValue()) {
1042 return input;
1043 }
1044
1045 assert(llvm::all_of(packOp.getAllOuterDims(),
1046 [](int64_t val) { return val == 1; }) &&
1047 "some outer dims are != 1");
1048
1049 Location loc = packOp.getLoc();
1050 ShapedType inputType = packOp.getSourceType();
1051 int64_t inputRank = inputType.getRank();
1052
1053 DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
1054 packOp.getDimAndTileMapping();
1055
1056 // The sizes of dynamic tiles
1057 SmallVector<Value> dynamicTileSizes;
1058
1059 // Collect dims for the padded shape.
1060 SmallVector<int64_t> paddedShape;
1061 for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
1062 // 1. Non-tiled outer dims.
1063 // These dims should be 1 and we simply preserve them.
1064 if (!tileAndPosMapping.count(Val: dimIdx)) {
1065 int64_t inputDimSize = inputType.getDimSize(idx: dimIdx);
1066 assert(inputDimSize == 1 &&
1067 "with all outer dims == 1, this non-tiled input dim should be 1!");
1068 paddedShape.push_back(Elt: inputDimSize);
1069 continue;
1070 }
1071
1072 // 2. Tiled outer dims
1073 // As all outer dims == 1, it is safe to use the tile size for the padded
1074 // shape.
1075 OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(Val: dimIdx);
1076
1077 // 2.1 Static tile sizes
1078 std::optional<int64_t> cstTileSize = getConstantIntValue(ofr: tileSizeForDim);
1079 if (cstTileSize.has_value()) {
1080 paddedShape.push_back(Elt: cstTileSize.value());
1081 continue;
1082 }
1083
1084 // 2.2 Dynamic tile sizes
1085 paddedShape.push_back(Elt: ShapedType::kDynamic);
1086
1087 // Get the value that holds the dynamic size.
1088 dynamicTileSizes.push_back(Elt: llvm::dyn_cast<Value>(Val&: tileSizeForDim));
1089 }
1090 auto resultType =
1091 RankedTensorType::get(shape: paddedShape, elementType: inputType.getElementType());
1092 return tensor::createPadHighOp(resType: resultType, source: input, pad: packOp.getPaddingValue(),
1093 /*nofold=*/false, loc, builder,
1094 dynOutDims: dynamicTileSizes);
1095}
1096
1097// Normalizes a permutation on a higher rank space to its actual size, e.g.
1098// perm = [1, 4, 2]
1099// becomes
1100// norm = [0, 2, 1]
1101static SmallVector<int64_t>
1102getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
1103 constexpr int64_t kNonTiledMarker = -1;
1104 SmallVector<int64_t> vec(rank, kNonTiledMarker);
1105 for (auto [index, value] : llvm::enumerate(First&: perm))
1106 vec[value] = index;
1107 SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
1108 C&: vec, Pred: [&](int64_t v) { return v != kNonTiledMarker; });
1109 // This inverts the permutation in addition to normalizing so invert back.
1110 return invertPermutationVector(permutation: normalizedPerm);
1111}
1112
1113// Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
1114// assuming rank reduction of unit outer dims.
1115static SmallVector<int64_t>
1116getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
1117 ArrayRef<int64_t> innerDimsPos,
1118 ArrayRef<int64_t> outerDimsPerm) {
1119 SmallVector<int64_t> rankReducedOuterDimsPerm;
1120 SmallVector<int64_t> outerDims;
1121 SmallVector<int64_t> innerDims;
1122 int64_t dim = 0;
1123 int64_t unpackedRank = shape.size();
1124 for (auto i : llvm::seq<unsigned>(Begin: 0, End: unpackedRank)) {
1125 if (llvm::is_contained(Range&: innerDimsPos, Element: i)) {
1126 innerDims.push_back(Elt: dim++);
1127 continue;
1128 }
1129 if (shape[i] == 1)
1130 continue;
1131 outerDims.push_back(Elt: dim++);
1132 if (!outerDimsPerm.empty())
1133 rankReducedOuterDimsPerm.push_back(Elt: outerDimsPerm[i]);
1134 }
1135
1136 // Get the position of the inner dims after permutation.
1137 SmallVector<int64_t> innerPerm =
1138 getPackUnpackNormalizedPerm(rank: unpackedRank, perm: innerDimsPos);
1139 applyPermutationToVector<int64_t>(inVec&: innerDims, permutation: innerPerm);
1140
1141 // Ditto for the outer dims.
1142 SmallVector<int64_t> perm = outerDims;
1143
1144 rankReducedOuterDimsPerm =
1145 getPackUnpackNormalizedPerm(rank: unpackedRank, perm: rankReducedOuterDimsPerm);
1146 if (!rankReducedOuterDimsPerm.empty())
1147 applyPermutationToVector<int64_t>(inVec&: perm, permutation: rankReducedOuterDimsPerm);
1148
1149 // The tile always ends up as the inner most dims after packing.
1150 perm.append(RHS: innerDims);
1151
1152 return perm;
1153}
1154
1155LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
1156 linalg::PackOp packOp, PatternRewriter &rewriter) const {
1157 // TODO: support the case that outer dimensions are not all 1s. A
1158 // tensor.expand_shape will be generated in this case.
1159 if (llvm::any_of(Range: packOp.getAllOuterDims(),
1160 P: [](int64_t dim) { return dim != 1; })) {
1161 return rewriter.notifyMatchFailure(
1162 arg&: packOp, msg: "not all outer dimensions of the result are 1s");
1163 }
1164
1165 Attribute zeroIdxAttr = rewriter.getIndexAttr(value: 0);
1166 Attribute oneIdxAttr = rewriter.getIndexAttr(value: 1);
1167 Location loc = packOp.getLoc();
1168
1169 Value input = getPackOpSourceOrPaddedSource(builder&: rewriter, packOp);
1170 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1171 packOp.getDimAndTileMapping();
1172 int64_t srcRank = packOp.getSourceRank();
1173 int64_t destRank = packOp.getDestRank();
1174 int64_t numTiles = destRank - srcRank;
1175
1176 // 1. Extract the inner tile sizes.
1177 // Where possible, values are replaced with constant attributes (to match the
1178 // behaviour of `getPackOpSourceOrPaddedSource`).
1179 SmallVector<OpFoldResult> tileSizes;
1180 for (auto i : llvm::seq<unsigned>(Begin: 0, End: srcRank)) {
1181 if (dimAndTileMapping.count(Val: i)) {
1182 // Rather than taking the tile size as is, extact the actual constant
1183 // value Attribute where possible, e.g.:
1184 // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
1185 auto [_, tileSize] =
1186 getSimplifiedOfrAndStaticSizePair(ofr: dimAndTileMapping[i], b&: rewriter);
1187 tileSizes.push_back(Elt: tileSize);
1188 }
1189 }
1190
1191 // 2. Transpose the input to match the inner tile order:
1192 // %init = tensor.empty()
1193 // %transposed_tile = linalg.transpose ins(%source_or_padded_source),
1194 // outs(%init)
1195 // Assumptions made:
1196 // 1. All outer dims are 1 - the corresponding transposition order doesn't
1197 // matter, but requires all dim indices to be present.
1198 SmallVector<int64_t> srcPermForTranspose;
1199 ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
1200 for (int64_t i = 0; i < srcRank; i++) {
1201 // We assume the `k` dimensions of the inner dim position, where `k` is the
1202 // rank of the inner tiling, correspond to the last `k` indices of the
1203 // transpose permutation. This is done by adding the indices not contained
1204 // in the inner dimension position in order from 0 to `n`. Where n is the
1205 // rank of the source tensor. For example if we have a source tensor with
1206 // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
1207 // indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
1208 if (llvm::is_contained(Range&: innerDimPos, Element: i))
1209 continue;
1210 srcPermForTranspose.push_back(Elt: i);
1211 }
1212 srcPermForTranspose.append(in_start: innerDimPos.begin(), in_end: innerDimPos.end());
1213
1214 LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
1215 << "perm: " << llvm::interleaved(srcPermForTranspose)
1216 << "\n");
1217
1218 // 2.1 Create tensor.empty (init value for TransposeOp)
1219 SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
1220 oneIdxAttr);
1221 transShapeForEmptyOp.append(RHS: tileSizes);
1222
1223 applyPermutationToVector<OpFoldResult>(inVec&: transShapeForEmptyOp,
1224 permutation: srcPermForTranspose);
1225 Value empty = rewriter.create<tensor::EmptyOp>(
1226 location: loc, args&: transShapeForEmptyOp, args: packOp.getSourceType().getElementType());
1227
1228 // 2.2 Create linalg.transpose
1229 auto transposedOp = rewriter.create<linalg::TransposeOp>(location: loc, args&: input, args&: empty,
1230 args&: srcPermForTranspose);
1231
1232 // 3. Insert the inner tile to the destination:
1233 // %inserted_tile = tensor.insert_slice(%transposed_tile)
1234 SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1235 SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1236 // Outer dims are all 1s!
1237 SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
1238 oneIdxAttr);
1239 SmallVector<int64_t> writeShape;
1240
1241 for (auto tileSize : packOp.getMixedTiles()) {
1242 auto [tileSizeStatic, tileSizeOfr] =
1243 getSimplifiedOfrAndStaticSizePair(ofr: tileSize, b&: rewriter);
1244 writeSizes.push_back(Elt: tileSizeOfr);
1245 writeShape.push_back(Elt: tileSizeStatic);
1246 }
1247
1248 // 4. Replace tensor.packOp with tensor.insert_slice created above
1249 auto insert = rewriter.create<tensor::InsertSliceOp>(
1250 location: loc, args: transposedOp.getResult()[0], args: packOp.getDest(), args&: writeOffsets,
1251 args&: writeSizes, args&: writeStrides);
1252 rewriter.replaceOp(op: packOp, newValues: insert.getResult());
1253
1254 return success();
1255}
1256
1257LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
1258 linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
1259 int64_t srcRank = unpackOp.getSourceRank();
1260 int64_t destRank = unpackOp.getDestRank();
1261 ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
1262 ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1263 if (llvm::any_of(Range: unpackOp.getTiledOuterDims(),
1264 P: [](int64_t dim) { return dim != 1; })) {
1265 return rewriter.notifyMatchFailure(
1266 arg&: unpackOp,
1267 msg: "require the tiled outer dimensions of the result are all 1s");
1268 }
1269
1270 // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
1271 // %extracted_tile = tensor.extract_slice(%unpack_op_input)
1272 Location loc = unpackOp.getLoc();
1273 Value source = unpackOp.getSource();
1274 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1275 unpackOp.getDimAndTileMapping();
1276 Attribute zeroIdxAttr = rewriter.getIndexAttr(value: 0);
1277 Attribute oneIdxAttr = rewriter.getIndexAttr(value: 1);
1278
1279 // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
1280 // dims:
1281 // [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
1282 SmallVector<int64_t> readShapeForExtractSlice;
1283 // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
1284 // outer-tiled-dims being all 1), this will be
1285 // [ outer-untiled-dims, tile-sizes ]
1286 SmallVector<OpFoldResult> extractSliceSizes;
1287 // The offset and strides attributes for ExtractSliceOp.
1288 SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
1289 SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
1290
1291 // Shape for EmptyOp that's used as the init value for TransposeOp below.
1292 // This should be:
1293 // [ outer-untiled-dims, tile-sizes ]
1294 // However, skip unit dims - TransposeOp (below) applies rank-reduced
1295 // permutation.
1296 SmallVector<OpFoldResult> shapeForEmptyOp;
1297
1298 for (auto i : llvm::seq<unsigned>(Begin: 0, End: destRank)) {
1299 // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
1300 //
1301 // As all outer tiled dims are 1, so the corresponding
1302 // slice size to read will also 1. As this will be rank-reducing "extract
1303 // slice" (i.e. the unit dims will be "collapsed"), there's no need to
1304 // update:
1305 // * the output shape for ExtractSliceOp, nor
1306 // * the shape for EmptyOp.
1307 if (dimAndTileMapping.count(Val: i)) {
1308 extractSliceSizes.push_back(Elt: oneIdxAttr);
1309 continue;
1310 }
1311
1312 // Compute sizes attribute for ExtractSliceOp + EmptyOp -
1313 // outer-untiled-dims
1314 if (ShapedType::isDynamic(dValue: srcShape[i])) {
1315 OpFoldResult dynamicDim =
1316 rewriter.create<tensor::DimOp>(location: loc, args&: source, args&: i).getResult();
1317 extractSliceSizes.push_back(Elt: dynamicDim);
1318 shapeForEmptyOp.push_back(Elt: dynamicDim);
1319 } else {
1320 extractSliceSizes.push_back(Elt: rewriter.getIndexAttr(value: srcShape[i]));
1321 if (srcShape[i] != 1)
1322 shapeForEmptyOp.push_back(Elt: rewriter.getIndexAttr(value: srcShape[i]));
1323 }
1324 // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
1325 // into account rank-reducing)
1326 if (srcShape[i] != 1) {
1327 readShapeForExtractSlice.push_back(Elt: srcShape[i]);
1328 }
1329 }
1330 // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
1331 // shape for EmptyOp.
1332 auto mixedTiles = unpackOp.getMixedTiles();
1333 extractSliceSizes.append(in_start: mixedTiles.begin(), in_end: mixedTiles.end());
1334 shapeForEmptyOp.append(in_start: mixedTiles.begin(), in_end: mixedTiles.end());
1335
1336 // Explicitly create the type for extract_slice op because the inner tile
1337 // size could be 1. We want to represent the whole inner tile in this case.
1338 auto tileShape = srcShape.drop_front(N: destRank);
1339 // Append the inner tile shape to the permuted and rank-reduced outer shape.
1340 readShapeForExtractSlice.append(in_start: tileShape.begin(), in_end: tileShape.end());
1341 Type elemType = unpackOp.getSourceType().getElementType();
1342 auto readType = RankedTensorType::get(shape: readShapeForExtractSlice, elementType: elemType);
1343 Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
1344 location: loc, args&: readType, args: unpackOp.getSource(), args&: extractSliceOffsets,
1345 args&: extractSliceSizes, args&: extractSliceStrides);
1346
1347 // 2. Transpose the tile to match the outer corresponding tile order.
1348 SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
1349 shape: srcShape.take_front(N: destRank), innerDimsPos, outerDimsPerm: unpackOp.getOuterDimsPerm());
1350 // Unpack is a transition out of packed space so we invert the permutation.
1351 perm = invertPermutationVector(permutation: perm);
1352 applyPermutationToVector<OpFoldResult>(inVec&: shapeForEmptyOp, permutation: perm);
1353
1354 Value empty =
1355 rewriter.create<tensor::EmptyOp>(location: loc, args&: shapeForEmptyOp, args&: elemType);
1356 auto transposedOp =
1357 rewriter.create<linalg::TransposeOp>(location: loc, args&: innerTile, args&: empty, args&: perm);
1358
1359 // 3. Handle in-complete tiles if needed. It truncates trailing data from the
1360 // transposed tile.
1361 int numLoops = shapeForEmptyOp.size();
1362 SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
1363 SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
1364 SmallVector<OpFoldResult> tileSizes;
1365 ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
1366 for (auto i : llvm::seq<unsigned>(Begin: 0, End: destRank)) {
1367 if (dimAndTileMapping.count(Val: i) || destShape[i] != 1)
1368 tileSizes.push_back(
1369 Elt: tensor::getMixedSize(builder&: rewriter, loc, value: unpackOp.getDest(), dim: i));
1370 }
1371
1372 auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
1373 location: loc, args: transposedOp.getResult()[0], args&: tileOffsets, args&: tileSizes, args&: tileStrides);
1374
1375 // 4. Insert the result to the destination tensor.
1376 SmallVector<OpFoldResult> writeSizes;
1377 SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1378 SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1379 for (int i = 0, idx = 0; i < destRank; ++i) {
1380 if (dimAndTileMapping.count(Val: i) || destShape[i] != 1)
1381 writeSizes.push_back(Elt: tileSizes[idx++]);
1382 else
1383 writeSizes.push_back(Elt: oneIdxAttr);
1384 }
1385 auto insert = rewriter.create<tensor::InsertSliceOp>(
1386 location: loc, args&: partialTile, args: unpackOp.getDest(), args&: writeOffsets, args&: writeSizes,
1387 args&: writeStrides);
1388 rewriter.replaceOp(op: unpackOp, newValues: insert.getResult());
1389
1390 return success();
1391}
1392
1393// The following are patterns for downscaling convolution ops with size-1
1394// window dimensions.
1395//
1396// Note that we'd eventually want to write such transformations in a generic
1397// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
1398// and then turning back to named ops. But for now it's fine to have a few
1399// patterns matching special ops to get started.
1400
1401template <typename Conv2DOp, typename Conv1DOp>
1402FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
1403 returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
1404 if (convOp.hasPureBufferSemantics())
1405 return failure(); // To be implemented.
1406
1407 Value input = convOp.getInputs().front();
1408 Value kernel = convOp.getInputs().back();
1409 Value output = convOp.getOutputs().front();
1410
1411 auto inputType = dyn_cast<RankedTensorType>(Val: input.getType());
1412 auto kernelType = dyn_cast<RankedTensorType>(Val: kernel.getType());
1413 auto outputType = dyn_cast<RankedTensorType>(Val: output.getType());
1414
1415 auto kernelShape = kernelType.getShape();
1416 auto outputShape = outputType.getShape();
1417
1418 // Get domain indices based on conv2D layout.
1419 auto [khIndex, kwIndex, ohIndex, owIndex] =
1420 TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
1421 convOp)
1422 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
1423 return std::make_tuple(args: 0, args: 1, args: 1, args: 2);
1424 })
1425 .Case([&](linalg::Conv2DNchwFchwOp op) {
1426 return std::make_tuple(args: 2, args: 3, args: 2, args: 3);
1427 })
1428 .Case([&](linalg::PoolingNhwcSumOp op) {
1429 return std::make_tuple(args: 0, args: 1, args: 1, args: 2);
1430 })
1431 .Case([&](linalg::PoolingNchwSumOp op) {
1432 return std::make_tuple(args: 0, args: 1, args: 2, args: 3);
1433 })
1434 .Case([&](linalg::PoolingNhwcMaxOp op) {
1435 return std::make_tuple(args: 0, args: 1, args: 1, args: 2);
1436 })
1437 .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
1438 return std::make_tuple(args: 0, args: 1, args: 1, args: 2);
1439 })
1440 .Case([&](linalg::PoolingNhwcMinOp op) {
1441 return std::make_tuple(args: 0, args: 1, args: 1, args: 2);
1442 })
1443 .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
1444 return std::make_tuple(args: 0, args: 1, args: 1, args: 2);
1445 })
1446 .Case([&](linalg::PoolingNchwMaxOp op) {
1447 return std::make_tuple(args: 0, args: 1, args: 2, args: 3);
1448 })
1449 .Default([&](Operation *op) {
1450 llvm_unreachable("unexpected conv2d/pool2d operation.");
1451 return std::make_tuple(args: 0, args: 0, args: 0, args: 0);
1452 });
1453
1454 // Only handle the case where at least one of the window dimensions is
1455 // of size 1. Other cases can rely on tiling to reduce to such cases.
1456 int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1457 int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1458 bool removeH = (khSize == 1 && ohSize == 1);
1459 bool removeW = (kwSize == 1 && owSize == 1);
1460 if (!removeH && !removeW)
1461 return failure();
1462
1463 // Get new shapes and types for all operands by removing the size-1
1464 // dimension.
1465 using RTTBuilder = RankedTensorType::Builder;
1466 RankedTensorType newInputType =
1467 RTTBuilder(inputType).dropDim(pos: (removeH ? ohIndex : owIndex));
1468 RankedTensorType newKernelType =
1469 RTTBuilder(kernelType).dropDim(pos: (removeH ? khIndex : kwIndex));
1470 RankedTensorType newOutputType =
1471 RTTBuilder(outputType).dropDim(pos: (removeH ? ohIndex : owIndex));
1472
1473 // Rank-reduce operands.
1474 Location loc = convOp.getLoc();
1475 Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1476 b&: rewriter, loc, tensor: input, targetType: newInputType);
1477 Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1478 b&: rewriter, loc, tensor: kernel, targetType: newKernelType);
1479 Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1480 b&: rewriter, loc, tensor: output, targetType: newOutputType);
1481
1482 // Rank-reduce strides and dilations too.
1483 // TODO: dropDim 1-liner helper.
1484 auto strides =
1485 llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
1486 strides.erase(strides.begin() + (removeH ? 0 : 1));
1487 auto stridesAttr = rewriter.getI64VectorAttr(values: strides);
1488
1489 auto dilations =
1490 llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
1491 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1492 auto dilationsAttr = rewriter.getI64VectorAttr(values: dilations);
1493
1494 auto conv1DOp = rewriter.create<Conv1DOp>(
1495 loc, newOutputType, ValueRange{newInput, newKernel},
1496 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1497
1498 // Insert back.
1499 Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1500 b&: rewriter, loc, tensor: conv1DOp.getResult(0), dest: output);
1501 rewriter.replaceOp(convOp, inserted);
1502
1503 return conv1DOp;
1504}
1505
1506template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
1507 Conv1DNwcWcfOp>;
1508template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
1509 Conv1DNcwFcwOp>;
1510template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
1511 PoolingNwcSumOp>;
1512template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
1513 PoolingNcwSumOp>;
1514template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
1515 PoolingNwcMaxOp>;
1516template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1517 PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1518template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
1519 PoolingNwcMinOp>;
1520template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1521 PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
1522template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
1523 PoolingNcwMaxOp>;
1524
1525FailureOr<DepthwiseConv1DNwcWcOp>
1526DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
1527 DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
1528 if (convOp.hasPureBufferSemantics())
1529 return failure(); // To be implemented.
1530
1531 Value input = convOp.getInputs().front();
1532 Value kernel = convOp.getInputs().back();
1533 Value output = convOp.getOutputs().front();
1534
1535 auto inputType = dyn_cast<RankedTensorType>(Val: input.getType());
1536 auto kernelType = dyn_cast<RankedTensorType>(Val: kernel.getType());
1537 auto outputType = dyn_cast<RankedTensorType>(Val: output.getType());
1538
1539 auto kernelShape = kernelType.getShape();
1540 auto outputShape = outputType.getShape();
1541
1542 // Only handle the case where at least one of the window dimensions is
1543 // of size 1. Other cases can rely on tiling to reduce to such cases.
1544 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1545 int64_t ohSize = outputShape[1], owSize = outputShape[2];
1546 bool removeH = (khSize == 1 && ohSize == 1);
1547 bool removeW = (kwSize == 1 && owSize == 1);
1548 if (!removeH && !removeW)
1549 return failure();
1550
1551 // Get new shapes and types for all operands by removing the size-1
1552 // dimension.
1553 using RTTBuilder = RankedTensorType::Builder;
1554 RankedTensorType newInputType =
1555 RTTBuilder(inputType).dropDim(pos: (removeH ? 1 : 2));
1556 RankedTensorType newKernelType =
1557 RTTBuilder(kernelType).dropDim(pos: (removeH ? 0 : 1));
1558 RankedTensorType newOutputType =
1559 RTTBuilder(outputType).dropDim(pos: removeH ? 1 : 2);
1560
1561 // Rank-reduce operands.
1562 Location loc = convOp.getLoc();
1563 Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1564 b&: rewriter, loc, tensor: input, targetType: newInputType);
1565 Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1566 b&: rewriter, loc, tensor: kernel, targetType: newKernelType);
1567 Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1568 b&: rewriter, loc, tensor: output, targetType: newOutputType);
1569
1570 // Rank-reduce strides and dilations too.
1571 // TODO: dropDim 1-liner helper.
1572 auto strides = llvm::to_vector<4>(Range: convOp.getStrides().getValues<int64_t>());
1573 strides.erase(CI: strides.begin() + (removeH ? 0 : 1));
1574 auto stridesAttr = rewriter.getI64VectorAttr(values: strides);
1575
1576 auto dilations =
1577 llvm::to_vector<4>(Range: convOp.getDilations().getValues<int64_t>());
1578 dilations.erase(CI: dilations.begin() + (removeH ? 0 : 1));
1579 auto dilationsAttr = rewriter.getI64VectorAttr(values: dilations);
1580
1581 auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
1582 location: loc, args&: newOutputType, args: ValueRange{newInput, newKernel},
1583 args: ValueRange{newOutput}, args&: stridesAttr, args&: dilationsAttr);
1584
1585 // Insert back.
1586 Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1587 b&: rewriter, loc, tensor: conv1DOp.getResult(i: 0), dest: output);
1588 rewriter.replaceOp(op: convOp, newValues: inserted);
1589
1590 return conv1DOp;
1591}
1592
1593FailureOr<Conv1DOp>
1594DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
1595 PatternRewriter &rewriter) const {
1596 if (convOp.hasPureBufferSemantics())
1597 return failure(); // To be implemented.
1598
1599 Value input = convOp.getInputs().front();
1600 Value kernel = convOp.getInputs().back();
1601 Value output = convOp.getOutputs().front();
1602
1603 auto inputType = dyn_cast<RankedTensorType>(Val: input.getType());
1604 auto kernelType = dyn_cast<RankedTensorType>(Val: kernel.getType());
1605 auto outputType = dyn_cast<RankedTensorType>(Val: output.getType());
1606
1607 auto kernelShape = kernelType.getShape();
1608 auto outputShape = outputType.getShape();
1609
1610 // Only handle the case where at least one of the window dimensions is
1611 // of size 1. Other cases can rely on tiling to reduce to such cases.
1612 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1613 int64_t ohSize = outputShape[0], owSize = outputShape[1];
1614 bool removeH = (khSize == 1 && ohSize == 1);
1615 bool removeW = (kwSize == 1 && owSize == 1);
1616 if (!removeH && !removeW)
1617 return failure();
1618
1619 // Get new shapes and types for all operands by removing the size-1
1620 // dimension.
1621 using RTTBuilder = RankedTensorType::Builder;
1622 RankedTensorType newInputType =
1623 RTTBuilder(inputType).dropDim(pos: (removeH ? 0 : 1));
1624 RankedTensorType newKernelType =
1625 RTTBuilder(kernelType).dropDim(pos: (removeH ? 0 : 1));
1626 RankedTensorType newOutputType =
1627 RTTBuilder(outputType).dropDim(pos: removeH ? 0 : 1);
1628
1629 // Rank-reduce operands.
1630 Location loc = convOp.getLoc();
1631 Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1632 b&: rewriter, loc, tensor: input, targetType: newInputType);
1633 Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1634 b&: rewriter, loc, tensor: kernel, targetType: newKernelType);
1635 Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1636 b&: rewriter, loc, tensor: output, targetType: newOutputType);
1637
1638 auto conv1DOp = rewriter.create<Conv1DOp>(location: loc, args&: newOutputType,
1639 args: ValueRange{newInput, newKernel},
1640 args: ValueRange{newOutput});
1641
1642 // Insert back.
1643 Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1644 b&: rewriter, loc, tensor: conv1DOp.getResult(i: 0), dest: output);
1645 rewriter.replaceOp(op: convOp, newValues: inserted);
1646
1647 return conv1DOp;
1648}
1649
1650void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
1651 PatternBenefit benefit) {
1652 patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
1653 Conv1DNwcWcfOp>,
1654 DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
1655 Conv1DNcwFcwOp>,
1656 DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
1657 arg: patterns.getContext(), args&: benefit);
1658 patterns.add<
1659 DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
1660 DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
1661 DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
1662 DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
1663 PoolingNwcMaxUnsignedOp>,
1664 DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
1665 DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
1666 PoolingNwcMinUnsignedOp>,
1667 DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
1668 arg: patterns.getContext(), args&: benefit);
1669}
1670
1671void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
1672 patterns.add<DecomposeOuterUnitDimsPackOpPattern>(arg: patterns.getContext());
1673 patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(arg: patterns.getContext());
1674}
1675
1676void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
1677 patterns.add<DecomposePadOpPattern>(arg: patterns.getContext());
1678}
1679

source code of mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp