1//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements logic and helpers to expose Linalg transforms as rewrite
10// patterns.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/Linalg/IR/Linalg.h"
19#include "mlir/Dialect/Linalg/Utils/Utils.h"
20#include "mlir/Dialect/SCF/Transforms/Transforms.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
23#include "mlir/Dialect/Tensor/Utils/Utils.h"
24#include "mlir/Dialect/Utils/IndexingUtils.h"
25#include "mlir/Dialect/Utils/StaticValueUtils.h"
26#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
27#include "mlir/Dialect/Vector/IR/VectorOps.h"
28#include "mlir/IR/AffineExpr.h"
29#include "mlir/IR/Matchers.h"
30#include "mlir/Pass/Pass.h"
31#include "mlir/Support/LLVM.h"
32#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
33#include "llvm/ADT/ScopeExit.h"
34#include "llvm/ADT/TypeSwitch.h"
35#include "llvm/Support/Debug.h"
36#include "llvm/Support/raw_ostream.h"
37#include <type_traits>
38#include <utility>
39
40#define DEBUG_TYPE "linalg-transforms"
41
42using namespace mlir;
43using namespace mlir::linalg;
44
45#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
46#define DBGSNL() (llvm::dbgs() << "\n")
47
48//===----------------------------------------------------------------------===//
49// Transformations exposed as functional-style API calls.
50//===----------------------------------------------------------------------===//
51
52//===----------------------------------------------------------------------===//
53// peelLoop transformation.
54//===----------------------------------------------------------------------===//
55
56/// Try to peel and canonicalize loop `op` and return the new result.
57/// Also applies affine_min/max bounds simplification on the fly where relevant.
58// TODO: Add support for scf.parallel and affine.for loops.
59SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
60 Operation *op) {
61 return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
62 .Case<scf::ForOp>(caseFn: [&](scf::ForOp forOp) {
63 scf::ForOp partialIteration;
64 if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
65 partialIteration)))
66 return partialIteration->getResults();
67 assert(!partialIteration && "expected that loop was not peeled");
68 return forOp->getResults();
69 })
70 .Default(defaultFn: [&](Operation *op) { return op->getResults(); });
71}
72
73/// Peel 'loops' and applies affine_min/max bounds simplification on the fly
74/// where relevant.
75void mlir::linalg::peelLoops(RewriterBase &rewriter,
76 ArrayRef<scf::ForOp> loops) {
77 for (auto loopOp : loops)
78 peelLoop(rewriter, loopOp);
79}
80
81//===----------------------------------------------------------------------===//
82// pack transformation.
83//===----------------------------------------------------------------------===//
84
85#ifndef NDEBUG
86/// Return true if `map` has 0 or 1 result function of AffineDimExpr(dim).
87static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
88 bool found = false;
89 for (AffineExpr e : map.getResults()) {
90 if (!e.isFunctionOfDim(position: dim))
91 continue;
92 if (found)
93 return false;
94 found = true;
95 }
96 return true;
97}
98#endif // NDEBUG
99
100/// Return the index of the first result of `map` that is a function of
101/// AffineDimExpr(dim), std::nullopt otherwise.
102static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
103 int64_t dim) {
104 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
105 AffineExpr expr = map.getResult(idx: i);
106 if (!expr.isFunctionOfDim(position: dim))
107 continue;
108 return i;
109 }
110 return std::nullopt;
111}
112
113/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
114/// `newDim` at `iteratorTypes.size()` by:
115/// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
116/// 2. Appending a `newDim` to the domain of every indexing map.
117/// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing
118/// by potentially adding a `newDim` result to `map`.
119/// The preserved invariant is that `iteratorTypes.size()` is always equal to
120/// `map.getNumDims()` for every map in `indexingMaps`.
121///
122/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
123/// Return a vector that records the optional packing for each operand.
124/// Return failure if the packed indexing cannot be represented with a LinalgOp.
125///
126/// Further details:
127/// ================
128/// The current implementation of packing (i.e. data tiling) consists of
129/// rewriting a linearized strip-mined form into a higher-dimensional access.
130/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
131/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
132/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
133///
134/// This rewrite into higher dimensional access is not possible for general
135/// AffineExpr in Linalg atm, it is restricted to an AffineDimExpr:
136/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
137/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
138/// The rewrite of the access would be a form not representable in Linalg:
139/// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
140/// Note however that as `J` and `ii` iterate, the accesses do not have a
141/// particular alignment, so packing does not achieve alignment in this case
142///
143/// In the future, we may want to consider a mixed-form that allows some
144/// alignment in the presence of multiple accesses:
145/// `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
146/// And would rewrite accesses as:
147/// `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
148static FailureOr<SmallVector<std::optional<int64_t>>>
149packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
150 SmallVectorImpl<utils::IteratorType> &iteratorTypes,
151 int64_t dim) {
152 int64_t newDim = iteratorTypes.size();
153 iteratorTypes.push_back(iteratorTypes[dim]);
154
155 SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
156 indexingMaps.size(), std::nullopt);
157 SmallVector<AffineMap> newMaps;
158 for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
159 ++operandIdx) {
160 AffineMap map = indexingMaps[operandIdx];
161
162 // Add the `newDim` to map whatever the case.
163 assert(map.getNumDims() == newDim && "num dims invariant violation");
164 map = map.shiftDims(shift: 1, offset: newDim);
165
166 // Get the at-most-1 index of the result that is a function of `dim`.
167 // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
168 // logically chunks dimension `dim` into `K * dim + newDim`, where the
169 // packing factor `K` is specified separately.
170 assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
171 "num results invariant violation");
172 auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
173 if (!maybeOperandDimensionToPack.has_value()) {
174 newMaps.push_back(Elt: map);
175 continue;
176 }
177
178 // We can only pack AffineDimExpr atm.
179 if (!isa<AffineDimExpr>(Val: map.getResult(idx: maybeOperandDimensionToPack.value())))
180 return failure();
181
182 // Add `newDim` to the results of the map.
183 map = map.insertResult(expr: Builder(map.getContext()).getAffineDimExpr(position: newDim),
184 pos: map.getNumResults());
185 newMaps.push_back(Elt: map);
186
187 // Record the that `operandIdx` is packed.
188 packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
189 }
190 indexingMaps = newMaps;
191
192 return packedDimPerIndexingMap;
193}
194
195namespace {
196
197/// Helper struct to encode packing along one dimension of a LinalgOp.
198struct PackedOperandsDim {
199 OpFoldResult packedSize;
200 SmallVector<std::optional<int64_t>> packedDimForEachOperand;
201};
202
203/// Helper struct to encode packing along all dimensions of a LinalgOp.
204struct PackedOperandsDimList {
205 void pushBack(PackedOperandsDim &&packedOperandsDims) {
206 spec.emplace_back(Args&: packedOperandsDims);
207 }
208 /// Return all the dims that have been packed for operand @ `operandPos`.
209 SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
210 /// Return all the pack sizes by which an operand @ `operandPos` is packed.
211 SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
212
213private:
214 SmallVector<PackedOperandsDim> spec;
215};
216
217} // namespace
218
219FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
220 tensor::PackOp packOp) {
221 // 1. Filter out NYI cases.
222 auto packedTensorType =
223 cast<RankedTensorType>(packOp->getResultTypes().front());
224 if (llvm::any_of(packOp.getStaticInnerTiles(),
225 [](int64_t size) { return ShapedType::isDynamic(size); })) {
226 return rewriter.notifyMatchFailure(
227 packOp,
228 "non-static shape NYI, needs a more powerful tensor.expand_shape op");
229 }
230
231 Location loc = packOp->getLoc();
232 OpBuilder::InsertionGuard g(rewriter);
233 rewriter.setInsertionPoint(packOp);
234
235 // 2. Compute the permutation vector to shuffle packed shape into the shape
236 // before any outer or inner permutations have been applied.
237 PackingMetadata packingMetadata = computePackingMetadata(
238 packedTensorType.getRank(), packOp.getInnerDimsPos());
239 SmallVector<int64_t> packedToStripMinedShapePerm =
240 tensor::getPackInverseDestPerm(packOp);
241
242 // 3. Compute the stripMinedShape: this is the packed shape before any outer
243 // or inner permutations have been applied.
244 SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
245 applyPermutationToVector(inVec&: stripMinedShape, permutation: packedToStripMinedShapePerm);
246
247 // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
248 SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
249 rewriter.getIndexAttr(0));
250 SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
251 rewriter.getIndexAttr(0));
252 for (auto [pos, innerSize] :
253 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
254 int outerPos =
255 packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
256 OpFoldResult origSize =
257 tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
258 OpFoldResult outerSize =
259 tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
260 AffineExpr s0, d0, d1;
261 bindDims(rewriter.getContext(), d0, d1);
262 bindSymbols(rewriter.getContext(), s0);
263 auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
264 highs[pos] = affine::makeComposedFoldedAffineApply(
265 rewriter, loc, map, {outerSize, origSize, innerSize});
266 }
267 RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
268 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
269 packingMetadata.reassociations);
270 Value paddingValue = packOp.getPaddingValue();
271 if (!paddingValue) {
272 paddingValue = rewriter.create<arith::ConstantOp>(
273 loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
274 }
275 auto padOp =
276 rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
277 highs, paddingValue, /*nofold=*/false);
278
279 LLVM_DEBUG(
280 DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
281 DBGS() << "insertPositions: ");
282 DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
283 DBGS() << "outerPositions: ");
284 DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
285 DBGS() << "packedShape: ");
286 DBGSNL();
287 llvm::interleaveComma(packedToStripMinedShapePerm,
288 DBGS() << "packedToStripMinedShapePerm: ");
289 DBGSNL(); llvm::interleaveComma(
290 packingMetadata.reassociations, DBGS() << "reassociations: ",
291 [&](ReassociationIndices ri) {
292 llvm::interleaveComma(ri, llvm::dbgs() << "|");
293 });
294 DBGSNL();
295 llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
296 DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
297
298 if (packOp.isLikePad()) {
299 // Pack ops which operate as simple pads may not produce legal
300 // tensor.insert_slice operations when the packed type does not rank reduce
301 // to the padded type.
302 SliceVerificationResult rankReduces =
303 isRankReducedType(packedTensorType, padOp.getResultType());
304
305 if (rankReduces == SliceVerificationResult::Success) {
306 // This pack is just a plain pad.
307 // Just insert the pad in the higher ranked tensor.
308 auto emptyOp =
309 rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
310 // Offsets.
311 SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
312 rewriter.getIndexAttr(0));
313 // Strides.
314 SmallVector<OpFoldResult> ones(packOp.getDestRank(),
315 rewriter.getIndexAttr(1));
316 SmallVector<OpFoldResult> sizes =
317 tensor::getMixedSizes(builder&: rewriter, loc, value: packOp.getDest());
318
319 auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
320 loc, /*source=*/padOp, /*dest=*/emptyOp,
321 /*offsets=*/zeros, sizes,
322 /*strides=*/ones);
323
324 LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
325
326 rewriter.replaceOp(packOp, insertSliceOp->getResults());
327
328 return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
329 /*transposeOp=*/nullptr};
330 }
331 }
332 // 5. Expand from the padded result to the stripMinedShape.
333 auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
334 loc,
335 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
336 padOp.getResult(), packingMetadata.reassociations);
337
338 // 6. Transpose stripMinedShape to packedShape.
339 SmallVector<int64_t> transpPerm =
340 invertPermutationVector(permutation: packedToStripMinedShapePerm);
341 auto transposeOp = rewriter.create<linalg::TransposeOp>(
342 loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
343
344 LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
345 DBGS() << "reshape op: " << reshapeOp; DBGSNL();
346 llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
347 DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
348
349 // 7. Replace packOp by transposeOp.
350 rewriter.replaceOp(packOp, transposeOp->getResults());
351
352 return LowerPackResult{padOp, reshapeOp, transposeOp};
353}
354
355FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
356 tensor::UnPackOp unPackOp) {
357 // 1. Filter out NYI cases.
358 if (!unPackOp.getOuterDimsPerm().empty() &&
359 !isIdentityPermutation(unPackOp.getOuterDimsPerm())) {
360 return rewriter.notifyMatchFailure(unPackOp,
361 "non-identity outer dims perm NYI");
362 }
363
364 Location loc = unPackOp->getLoc();
365 OpBuilder::InsertionGuard g(rewriter);
366 rewriter.setInsertionPoint(unPackOp);
367
368 RankedTensorType packedTensorType = unPackOp.getSourceType();
369 int64_t packedRank = packedTensorType.getRank();
370
371 OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
372 auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
373 if (unPackOp.isLikeUnPad()) {
374 // This unpack is just a plain unpad.
375 // Just extract the slice from the higher ranked tensor.
376 ArrayRef<int64_t> destShape = destTensorType.getShape();
377 // The inner dimensions stay the same as the destination tensor, but the
378 // outer ones are additional 1s.
379 SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
380 sizes.append(tensor::getMixedSizes(builder&: rewriter, loc, value: unPackOp.getDest()));
381
382 auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
383 loc, destTensorType, unPackOp.getSource(),
384 SmallVector<OpFoldResult>(packedRank, zero), sizes,
385 SmallVector<OpFoldResult>(packedRank, one));
386
387 rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
388
389 return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
390 /*reshapeOp=*/nullptr, extractSliceOp};
391 }
392 // 2. Compute the permutation vector to move the last `numPackedDims` into
393 // the `innerPosDims` of a shape of rank `packedRank`.
394 int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
395 auto lastDims = llvm::to_vector(
396 Range: llvm::seq<int64_t>(Begin: packedRank - numPackedDims, End: packedRank));
397 PackingMetadata packingMetadata =
398 computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
399 SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
400 packedRank, lastDims, packingMetadata.insertPositions);
401
402 // 3. Compute the stripMinedShape: this is the packed shape without outer and
403 // inner permutations.
404 SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
405 applyPermutationToVector(inVec&: stripMinedShape, permutation: lastDimsToInsertPositionsPerm);
406
407 // 4. Transpose packedShape to stripMinedShape.
408 RankedTensorType stripMinedTensorType =
409 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
410 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
411 stripMinedTensorType, packingMetadata.reassociations);
412
413 // Get dynamic dims from input tensor based on lastDimsToInsertPositionsPerm
414 // permutation.
415 SmallVector<OpFoldResult, 4> dims =
416 tensor::getMixedSizes(builder&: rewriter, loc, value: unPackOp.getSource());
417 applyPermutationToVector(inVec&: dims, permutation: lastDimsToInsertPositionsPerm);
418 auto emptyOp = rewriter.create<tensor::EmptyOp>(
419 loc, dims, stripMinedTensorType.getElementType());
420 auto transposeOp = rewriter.create<linalg::TransposeOp>(
421 loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
422
423 LLVM_DEBUG(
424 DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
425 DBGS() << "insertPositions: ");
426 DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
427 DBGS() << "packedShape: ");
428 DBGSNL();
429 llvm::interleaveComma(lastDimsToInsertPositionsPerm,
430 DBGS() << "lastDimsToInsertPositionsPerm: ");
431 DBGSNL(); llvm::interleaveComma(
432 packingMetadata.reassociations, DBGS() << "reassociations: ",
433 [&](ReassociationIndices ri) {
434 llvm::interleaveComma(ri, llvm::dbgs() << "|");
435 });
436 DBGSNL();
437 llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
438 DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
439
440 // 5. Collapse from the stripMinedShape to the padded result.
441 auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
442 loc, collapsedType, transposeOp->getResult(0),
443 packingMetadata.reassociations);
444
445 // 6. ExtractSlice.
446 int64_t destRank = destTensorType.getRank();
447 auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
448 loc, destTensorType, reshapeOp->getResult(0),
449 SmallVector<OpFoldResult>(destRank, zero),
450 tensor::getMixedSizes(builder&: rewriter, loc, value: unPackOp.getDest()),
451 SmallVector<OpFoldResult>(destRank, one));
452
453 // 7. Inject a copy to preserve DPS.
454 auto copyOp = rewriter.create<linalg::CopyOp>(
455 loc, extractSliceOp->getResult(0), unPackOp.getDest());
456
457 // 8. Replace unPackOp by extractSliceOp.
458 rewriter.replaceOp(unPackOp, copyOp->getResults());
459
460 return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
461}
462
463SmallVector<int64_t>
464PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
465 SmallVector<int64_t> res;
466 for (auto &i : spec) {
467 if (!i.packedDimForEachOperand[operandPos].has_value())
468 continue;
469 res.push_back(Elt: i.packedDimForEachOperand[operandPos].value());
470 }
471 return res;
472}
473
474SmallVector<OpFoldResult>
475PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
476 SmallVector<OpFoldResult> res;
477 for (auto &i : spec) {
478 if (!i.packedDimForEachOperand[operandPos].has_value())
479 continue;
480 res.push_back(Elt: i.packedSize);
481 }
482 return res;
483}
484
485/// Implement packing of a single LinalgOp by performing packing by
486/// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator.
487/// Return the packed Linalg op on success, failure otherwise.
488FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
489 linalg::LinalgOp linalgOp,
490 ArrayRef<OpFoldResult> packedSizes) {
491 if (packedSizes.size() != linalgOp.getNumLoops()) {
492 return rewriter.notifyMatchFailure(linalgOp,
493 "incorrect number of pack sizes");
494 }
495
496 Location loc = linalgOp->getLoc();
497 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
498 SmallVector<utils::IteratorType> iteratorTypes =
499 linalgOp.getIteratorTypesArray();
500 LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
501 llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
502 llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
503 DBGSNL(););
504
505 SmallVector<tensor::PackOp> packOps;
506 SmallVector<tensor::UnPackOp> unPackOps;
507 // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
508 PackedOperandsDimList listOfPackedOperandsDim;
509 for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
510 std::optional<int64_t> maybeConstant = getConstantIntValue(ofr: packedSizes[i]);
511 // Skip tile sizes explicitly set to 0.
512 if (maybeConstant.has_value() && maybeConstant.value() == 0)
513 continue;
514
515 PackedOperandsDim packedOperandsDims;
516 packedOperandsDims.packedSize = packedSizes[i];
517 FailureOr<SmallVector<std::optional<int64_t>>>
518 maybePackedDimForEachOperand =
519 packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
520 if (failed(result: maybePackedDimForEachOperand))
521 return failure();
522 packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
523 listOfPackedOperandsDim.pushBack(packedOperandsDims: std::move(packedOperandsDims));
524
525 LLVM_DEBUG(
526 DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
527 << "\n";
528 llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
529 llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
530 llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
531 DBGS() << "packedDimForEachOperand: ");
532 DBGSNL(););
533 }
534
535 // Step 2. Propagate packing to all LinalgOp operands.
536 SmallVector<Value> inputsAndInits, results;
537 SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
538 linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
539 SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
540 for (const auto &operandsList : {inputOperands, initOperands}) {
541 for (OpOperand *opOperand : operandsList) {
542 int64_t pos = opOperand->getOperandNumber();
543 Value operand = opOperand->get();
544 SmallVector<int64_t> innerPos =
545 listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
546 SmallVector<OpFoldResult> innerPackSizes =
547 listOfPackedOperandsDim.extractPackSizesForOperand(pos);
548 LLVM_DEBUG(
549 DBGS() << "operand: " << operand << "\n";
550 llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL();
551 llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: ");
552 DBGSNL(););
553 if (innerPackSizes.empty()) {
554 inputsAndInits.push_back(operand);
555 continue;
556 }
557 Value dest = tensor::PackOp::createDestinationTensor(
558 rewriter, loc, operand, innerPackSizes, innerPos,
559 /*outerDimsPerm=*/{});
560 ShapedType operandType = cast<ShapedType>(operand.getType());
561 bool areConstantTiles =
562 llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
563 return getConstantIntValue(tile).has_value();
564 });
565 if (areConstantTiles && operandType.hasStaticShape() &&
566 !tensor::PackOp::requirePaddingValue(
567 operandType.getShape(), innerPos,
568 cast<ShapedType>(dest.getType()).getShape(), {},
569 innerPackSizes)) {
570 packOps.push_back(rewriter.create<tensor::PackOp>(
571 loc, operand, dest, innerPos, innerPackSizes));
572 } else {
573 // TODO: value of the padding attribute should be determined by
574 // consumers.
575 auto zeroAttr =
576 rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
577 Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
578 packOps.push_back(rewriter.create<tensor::PackOp>(
579 loc, operand, dest, innerPos, innerPackSizes, zero));
580 }
581 inputsAndInits.push_back(packOps.back());
582 }
583 }
584
585 // Step 3. Build the packed op, use the type of `inits` as result types.
586 ValueRange inputs =
587 ValueRange{inputsAndInits}.take_front(n: linalgOp.getNumDpsInputs());
588 ValueRange inits =
589 ValueRange{inputsAndInits}.take_back(n: linalgOp.getNumDpsInits());
590 auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
591 linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
592 iteratorTypes);
593 packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
594
595 // Step 4. Propagate packing to all the op results.
596 for (OpResult result : packedLinalgOp->getResults()) {
597 int64_t resultNum = result.getResultNumber();
598 tensor::PackOp maybePackedInit =
599 inits[resultNum].getDefiningOp<tensor::PackOp>();
600 if (!maybePackedInit) {
601 results.push_back(result);
602 continue;
603 }
604 // Build the symmetrical UnPackOp to the existing PackOp.
605 unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
606 packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
607 maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
608 results.push_back(unPackOps.back());
609 }
610
611 // Step 5. Replace `linalgOp`.
612 rewriter.replaceOp(linalgOp, results);
613
614 // Return packedLinalgOp.
615 return PackResult{packOps,
616 cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
617 unPackOps};
618}
619
620//===----------------------------------------------------------------------===//
621// packTranspose transformation.
622//===----------------------------------------------------------------------===//
623
624/// Return a copy of `tensorType` after permutation by `permutationVector`.
625// Note: Should be a new method in of MemRef/RankedTensor/VectorType::Builder
626// but this would introduce a dependence on Dialect in IR.
627// TODO: Restructure.
628static RankedTensorType permuteShape(RankedTensorType tensorType,
629 ArrayRef<int64_t> permutationVector) {
630 SmallVector<int64_t> shape(tensorType.getShape());
631 applyPermutationToVector(inVec&: shape, permutation: permutationVector);
632 return RankedTensorType::Builder(tensorType).setShape(shape);
633}
634
635/// Return a new GenericOp obtained by transposing opOperand by the permutation
636/// vector:
637/// - the corresponding indexing map is transposed by `permutation`
638/// - the corresponding operand value is replaced by `transposedValue`
639/// `linalgOp` is replaced by the return op in the process.
640/// Asserts that `transposedValue` is of the proper transposed ShapedType.
641static LinalgOp transposeOneLinalgOperandAndReplace(
642 RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
643 ArrayRef<int64_t> permutation, Value transposedValue) {
644 // Sanity check the operand.
645 assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
646
647 // Sanity check of the expected transposed tensor type.
648 auto tensorType = permuteShape(
649 cast<RankedTensorType>(opOperand.get().getType()), permutation);
650 (void)tensorType;
651 assert(tensorType == transposedValue.getType() &&
652 "expected tensor type mismatch");
653
654 // Compute the transposed indexing map.
655 // Sigh unsigned pollution.
656 SmallVector<unsigned> tmpTransposition = llvm::to_vector(
657 Range: llvm::map_range(C&: permutation, F: [](int64_t i) -> unsigned { return i; }));
658 AffineMap permutationMap =
659 AffineMap::getPermutationMap(permutation: tmpTransposition, context: rewriter.getContext());
660 AffineMap transposedMap =
661 permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
662
663 // Set the transposed indexing map in the proper position.
664 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
665 indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
666 // Set the transposedValue in the proper operand position.
667 SmallVector<Value> operands = linalgOp->getOperands();
668 operands[opOperand.getOperandNumber()] = transposedValue;
669
670 ValueRange operandsRef(operands);
671 auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
672 /*location=*/linalgOp->getLoc(),
673 /*resultTensorTypes=*/
674 operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
675 /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
676 /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
677 /*indexingMaps=*/indexingMaps,
678 /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
679 transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
680 rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
681
682 return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
683}
684
685FailureOr<PackTransposeResult>
686linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
687 linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
688 ArrayRef<int64_t> outerPerm,
689 ArrayRef<int64_t> innerPerm) {
690 Location loc = linalgOp.getLoc();
691
692 // Step 1. Transpose packOp.
693 rewriter.setInsertionPoint(packOp);
694 tensor::PackOp transposedPackOp =
695 packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
696
697 if (!packOp.getResult().hasOneUse())
698 return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");
699
700 OpOperand &packUse = *packOp->getUses().begin();
701 if (packUse.getOwner() != linalgOp) {
702 return rewriter.notifyMatchFailure(
703 linalgOp, "not a single use by the LinalgOp target");
704 }
705 if (maybeUnPackOp &&
706 (!linalgOp.isDpsInit(&packUse) ||
707 maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
708 return rewriter.notifyMatchFailure(linalgOp,
709 "not produced by the LinalgOp target");
710 }
711
712 // Step 2. Transpose linalgOp.
713 // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
714 // identity. Don't rely on it.
715 int64_t numLeadingDims = packOp.getSourceRank();
716 int64_t numTrailingDims = packOp.getInnerDimsPos().size();
717 // Step 2.a. Compute the permutation on the whole operand.
718 // Leading part just reuse the outerPerm.
719 SmallVector<int64_t> permutation(outerPerm);
720 if (permutation.empty())
721 llvm::append_range(C&: permutation, R: llvm::seq<int64_t>(Begin: 0, End: numLeadingDims));
722 // Trailing part needs to reindex positions by `numLeadingDims`.
723 if (innerPerm.empty()) {
724 llvm::append_range(
725 C&: permutation,
726 R: llvm::seq<int64_t>(Begin: numLeadingDims, End: numLeadingDims + numTrailingDims));
727 } else {
728 llvm::append_range(permutation,
729 llvm::map_range(innerPerm, [&](int64_t pos) {
730 return numLeadingDims + pos;
731 }));
732 }
733 if (!isPermutationVector(interchange: permutation))
734 return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");
735
736 // Step 2.b. Save the transposedPackUse operand number in case we need to
737 // get the tied OpResult after `linalgOp` has been replaced.
738 int64_t packUseOperandNumber = packUse.getOperandNumber();
739 // Step 2.c. Actually perform the transposition.
740 rewriter.setInsertionPoint(linalgOp);
741 linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
742 rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
743
744 // Step 3. Maybe transpose unPackOp.
745 tensor::UnPackOp transposedUnPackOp;
746 if (maybeUnPackOp) {
747 OpOperand &opOperand =
748 transposedLinalgOp->getOpOperand(packUseOperandNumber);
749 OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
750 rewriter.setInsertionPoint(maybeUnPackOp);
751 transposedUnPackOp = maybeUnPackOp.createTransposedClone(
752 rewriter, loc, transposedResult, innerPerm, outerPerm);
753
754 rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
755 }
756
757 // Step 4. Finally, replace packOp now that we don't need it anymore.
758 rewriter.replaceOp(packOp, transposedPackOp->getResults());
759
760 return PackTransposeResult{transposedPackOp, transposedLinalgOp,
761 transposedUnPackOp};
762}
763
764//===----------------------------------------------------------------------===//
765// packMatmulGreedily transformation.
766//===----------------------------------------------------------------------===//
767
768/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
769/// and n are proper parallel dimensions and k is a proper reduction
770/// dimension. Packing occurs by rewriting the op as a linalg.generic and
771/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
772/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
773/// to reorder {m, n, k} into one of the 8 possible forms. The outer
774/// dimensions of the operands are not permuted at this time, this is left for
775/// future work.
776FailureOr<PackResult>
777linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
778 ArrayRef<OpFoldResult> mnkPackedSizes,
779 ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
780 ArrayRef<int64_t> mnkOrder) {
781 assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
782 assert((mnkPaddedSizesNextMultipleOf.empty() ||
783 mnkPaddedSizesNextMultipleOf.size() == 3) &&
784 "num of packing sizes next multiple should be empty or of size 3");
785 assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
786 assert(isPermutationVector(mnkOrder) && "expected a permutation");
787
788 int64_t numLoops = linalgOp.getNumLoops();
789 if (numLoops <= 2) {
790 LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
791 << numLoops << "\nin: " << linalgOp << "\n");
792 return rewriter.notifyMatchFailure(
793 linalgOp, "need 3+ loops to find a matmul to pack");
794 }
795
796 // Locally adjust the desired iterator position of mnk and packing sizes.
797 int64_t numPackedDims = mnkPackedSizes.size();
798 SmallVector<int64_t> mmnnkkPos(numPackedDims);
799 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
800 mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
801 SmallVector<OpFoldResult> packedSizes(numPackedDims);
802 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
803 packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
804 SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
805 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
806 paddedSizesNextMultipleOf[mnkOrder[i]] =
807 mnkPaddedSizesNextMultipleOf.empty() ? 0
808 : mnkPaddedSizesNextMultipleOf[i];
809 }
810
811 // 1. Infer dims that are important for matmul.
812 FailureOr<ContractionDimensions> maybeDimensions =
813 inferContractionDims(linalgOp);
814 if (failed(result: maybeDimensions)) {
815 LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
816 << "\n");
817 return rewriter.notifyMatchFailure(linalgOp,
818 "couldn't infer matmul iterators");
819 }
820
821 // 2. Normalize linalgOp to an kmn-matmul-like with [red, par, par] most
822 // minor iterators. In cases with multiple options for m, n, k bias towards
823 // the most minor embedding.
824 // If we wanted a different normalization order, this is where it would have
825 // to plug a heuristic.
826 int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
827 kPos = maybeDimensions->k.back();
828 LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
829 DBGS() << "Start packing generic op greedily with (m@" << mPos
830 << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
831 << "\n";);
832
833 // 2.a. Rewrite as a generic.
834 auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
835 if (!genericOp) {
836 FailureOr<GenericOp> generalizeResult =
837 generalizeNamedOp(rewriter, linalgOp);
838 assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
839 genericOp = *generalizeResult;
840 }
841
842 // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
843 // iterators. Note that this only normalized the iteration order and does
844 // not change the indexings of any operand.
845 SmallVector<int64_t> permutation =
846 computePermutationVector(permSize: numLoops, positions: {mPos, nPos, kPos}, desiredPositions: mmnnkkPos);
847 LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
848 // Sign .. unsigned pollution.
849 SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
850 FailureOr<GenericOp> interchangeResult =
851 interchangeGenericOp(rewriter, genericOp, unsignedPerm);
852 assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
853 genericOp = *interchangeResult;
854 LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
855
856 // At this point, the op iterators are normalized to {leading, k, m, n}.
857 // The layouts induced by packing will always be:
858 // - LHS{leading_lhs, kk, mm}
859 // - RHS{leading_rhs, kk, nn}
860 // - RES{leading_res, mm, nn}
861 // If we wanted to change the packed order, we would reorder (k, m, n) to
862 // something else above.
863 //
864 // Additional permutations of the outer dims of the operands (i.e.
865 // leading_lhs, leading_rhs and leading_res) could follow by computing the
866 // desired outerPerm for each operand.
867 // This is left for future work.
868
869 // TODO: this creates too much IR, go use reifyResultShapes.
870 SmallVector<Range, 4> loopRanges =
871 cast<LinalgOp>(genericOp.getOperation())
872 .createLoopRanges(rewriter, genericOp.getLoc());
873
874 // Add leading zeros to match numLoops, we only pack the last 3 dimensions
875 // post interchange.
876 LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
877 DBGS() << "paddedSizesNextMultipleOf: ");
878 DBGSNL(););
879 LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
880 [](Range r) { llvm::dbgs() << r.size; });
881 DBGSNL(););
882 SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
883 rewriter.getIndexAttr(0));
884 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
885 if (paddedSizesNextMultipleOf[i] == 0) {
886 adjustedPackedSizes.push_back(Elt: packedSizes[i]);
887 continue;
888 }
889 AffineExpr d0, s0;
890 bindDims(ctx: rewriter.getContext(), exprs&: d0);
891 bindSymbols(ctx: rewriter.getContext(), exprs&: s0);
892 adjustedPackedSizes.push_back(Elt: affine::makeComposedFoldedAffineApply(
893 rewriter, genericOp->getLoc(), d0.ceilDiv(other: s0) * s0,
894 {loopRanges[adjustedPackedSizes.size()].size,
895 rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
896 }
897 LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
898 DBGS() << "adjustedPackedSizes: ");
899 DBGSNL(););
900
901 // TODO: If we wanted to give the genericOp a name after packing, after
902 // calling `pack` would be a good time. One would still need to check that
903 // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
904 // also allow degenerate matmul cases (i.e. matvec, dot).
905 return pack(rewriter, genericOp, adjustedPackedSizes);
906}
907
908//===----------------------------------------------------------------------===//
909// Transformations exposed as rewrite patterns.
910//===----------------------------------------------------------------------===//
911
912LinalgTilingOptions &
913mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
914 assert(!tileSizeComputationFunction && "tile sizes already set");
915 SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
916 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
917 OpBuilder::InsertionGuard guard(b);
918 b.setInsertionPointToStart(
919 &op->getParentOfType<func::FuncOp>().getBody().front());
920 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
921 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
922 return v;
923 }));
924 };
925 return *this;
926}
927
928LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
929 memref::CopyOp copyOp, PatternRewriter &rewriter) const {
930 return vectorizeCopy(rewriter, copyOp);
931}
932
933/// Filling `dest` using FillOp constant padding value if possible.
934/// Otherwise, generate a tensor::GenerateOp.
935Value GeneralizePadOpPattern::createFillOrGenerateOp(
936 RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
937 const SmallVector<Value> &dynSizes) const {
938 auto padValue = padOp.getConstantPaddingValue();
939 if (padValue)
940 return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
941
942 // Fill could not be optimized: Lower to tensor::GenerateOp with region.
943 auto generateOp = rewriter.create<tensor::GenerateOp>(
944 padOp.getLoc(), padOp.getResultType(), dynSizes);
945 // Copy region to new op.
946 IRMapping bvm;
947 padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
948 return generateOp;
949}
950
951LogicalResult
952GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
953 PatternRewriter &rewriter) const {
954 // Given an OpFoldResult, return an index-typed value.
955 auto getIdxValue = [&](OpFoldResult ofr) {
956 if (auto val = llvm::dyn_cast_if_present<Value>(Val&: ofr))
957 return val;
958 return rewriter
959 .create<arith::ConstantIndexOp>(
960 padOp.getLoc(), cast<IntegerAttr>(ofr.get<Attribute>()).getInt())
961 .getResult();
962 };
963
964 auto resultType = padOp.getResultType();
965 // Compute size of EmptyOp. Any combination of static/dynamic is supported.
966 SmallVector<Value> dynSizes;
967 SmallVector<int64_t> staticSizes;
968 for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
969 if (resultType.isDynamicDim(dim)) {
970 auto srcSize = getIdxValue(tensor::getMixedSize(builder&: rewriter, loc: padOp.getLoc(),
971 value: padOp.getSource(), dim));
972 // Add low and high padding value.
973 auto plusLow = rewriter.createOrFold<arith::AddIOp>(
974 padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
975 auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
976 padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
977 dynSizes.push_back(Elt: plusHigh);
978 }
979 staticSizes.push_back(Elt: resultType.getDimSize(dim));
980 }
981
982 // Init tensor and fill it with padding.
983 Value emptyTensor = rewriter.create<tensor::EmptyOp>(
984 padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
985 Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);
986
987 // Try optimize the copy of source.
988 if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
989 return success();
990
991 // tensor::PadOps cannot be optimized. Generate a InsertSliceOp instead
992 // for copying the PadOp source.
993 auto sourceType = padOp.getSourceType();
994 // Compute size of source of tensor::PadOp.
995 SmallVector<OpFoldResult> srcSizes =
996 tensor::getMixedSizes(builder&: rewriter, loc: padOp.getLoc(), value: padOp.getSource());
997 // Strides of InsertSliceOp are all 1.
998 SmallVector<OpFoldResult> strides(sourceType.getRank(),
999 rewriter.getIndexAttr(1));
1000 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
1001 padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
1002 strides);
1003
1004 return success();
1005}
1006
1007LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
1008 tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
1009 if (!sliceOp.hasUnitStride())
1010 return failure();
1011
1012 auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
1013 if (!padOp)
1014 return failure();
1015
1016 bool zeroSliceGuard = true;
1017 if (controlFn) {
1018 if (std::optional<bool> control = controlFn(sliceOp))
1019 zeroSliceGuard = *control;
1020 else
1021 return failure();
1022 }
1023
1024 FailureOr<TilingResult> tilingResult =
1025 tensor::bubbleUpPadSlice(b&: rewriter, padOp: padOp, offsets: sliceOp.getMixedOffsets(),
1026 sizes: sliceOp.getMixedSizes(), generateZeroSliceGuard: zeroSliceGuard);
1027 if (failed(result: tilingResult))
1028 return failure();
1029 // All shapes are static and the data source is actually used. Rewrite into
1030 // pad(extract_slice(x)).
1031 rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
1032 return success();
1033}
1034
1035/// Returns a tensor.pad op if padding value is set. Otherwise, returns the
1036/// source directly. The method assumes that the `packOp` has static shapes.
1037static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
1038 tensor::PackOp packOp) {
1039 Value input = packOp.getSource();
1040 if (!packOp.getPaddingValue()) {
1041 return input;
1042 }
1043
1044 Location loc = packOp.getLoc();
1045 ShapedType inputType = packOp.getSourceType();
1046 int64_t inputRank = inputType.getRank();
1047 assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
1048 [](int64_t val) { return val == 1; }));
1049
1050 SmallVector<int64_t> paddedShape;
1051 DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
1052 packOp.getDimAndTileMapping();
1053 for (int64_t dim = 0; dim < inputRank; ++dim) {
1054 int64_t size = inputType.getDimSize(dim);
1055 if (!tileAndPosMapping.count(Val: dim)) {
1056 paddedShape.push_back(Elt: size);
1057 continue;
1058 }
1059
1060 // The size is less than or equal to tileSize because outer dims are all 1s.
1061 std::optional<int64_t> tileSize =
1062 getConstantIntValue(ofr: tileAndPosMapping.lookup(Val: dim));
1063 assert(tileSize.has_value() && "dynamic inner tile size is not supported");
1064 paddedShape.push_back(Elt: tileSize.value());
1065 }
1066 auto resultType =
1067 RankedTensorType::get(paddedShape, inputType.getElementType());
1068 return tensor::createPadHighOp(type: resultType, source: input, pad: packOp.getPaddingValue(),
1069 /*nofold=*/false, loc, builder);
1070}
1071
1072// Normalizes a permutation on a higher rank space to its actual size, e.g.
1073// perm = [1, 4, 2]
1074// becomes
1075// norm = [0, 2, 1]
1076static SmallVector<int64_t>
1077getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
1078 constexpr int64_t kNonTiledMarker = -1;
1079 SmallVector<int64_t> vec(rank, kNonTiledMarker);
1080 for (auto [index, value] : llvm::enumerate(First&: perm))
1081 vec[value] = index;
1082 SmallVector<int64_t> normalizedPerm = llvm::to_vector(Range: llvm::make_filter_range(
1083 Range&: vec, Pred: [&](int64_t v) { return v != kNonTiledMarker; }));
1084 // This inverts the permutation in addition to normalizing so invert back.
1085 return invertPermutationVector(permutation: normalizedPerm);
1086}
1087
1088// Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
1089// assuming rank reduction of unit outer dims.
1090static SmallVector<int64_t>
1091getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
1092 ArrayRef<int64_t> innerDimsPos,
1093 ArrayRef<int64_t> outerDimsPerm) {
1094 SmallVector<int64_t> rankReducedOuterDimsPerm;
1095 SmallVector<int64_t> outerDims;
1096 SmallVector<int64_t> innerDims;
1097 int64_t dim = 0;
1098 int64_t unpackedRank = shape.size();
1099 for (auto i : llvm::seq<unsigned>(Begin: 0, End: unpackedRank)) {
1100 if (llvm::is_contained(Range&: innerDimsPos, Element: i)) {
1101 innerDims.push_back(Elt: dim++);
1102 continue;
1103 }
1104 if (shape[i] == 1)
1105 continue;
1106 outerDims.push_back(Elt: dim++);
1107 if (!outerDimsPerm.empty())
1108 rankReducedOuterDimsPerm.push_back(Elt: outerDimsPerm[i]);
1109 }
1110
1111 // Get the position of the inner dims after permutation.
1112 SmallVector<int64_t> innerPerm =
1113 getPackUnpackNormalizedPerm(rank: unpackedRank, perm: innerDimsPos);
1114 applyPermutationToVector<int64_t>(inVec&: innerDims, permutation: innerPerm);
1115
1116 // Ditto for the outer dims.
1117 SmallVector<int64_t> perm = outerDims;
1118
1119 rankReducedOuterDimsPerm =
1120 getPackUnpackNormalizedPerm(rank: unpackedRank, perm: rankReducedOuterDimsPerm);
1121 if (!rankReducedOuterDimsPerm.empty())
1122 applyPermutationToVector<int64_t>(inVec&: perm, permutation: rankReducedOuterDimsPerm);
1123
1124 // The tile always ends up as the inner most dims after packing.
1125 perm.append(RHS: innerDims);
1126
1127 return perm;
1128}
1129
1130LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
1131 tensor::PackOp packOp, PatternRewriter &rewriter) const {
1132 if (llvm::any_of(packOp.getMixedTiles(),
1133 [](OpFoldResult tile) { return tile.is<Value>(); })) {
1134 return rewriter.notifyMatchFailure(packOp,
1135 "require inner tile sizes being static");
1136 }
1137
1138 // TODO: support the case that outer dimensions are not all 1s. A
1139 // tensor.expand_shape will be generated in this case.
1140 auto innerDimsPos = packOp.getInnerDimsPos();
1141 int64_t srcRank = packOp.getSourceRank();
1142 auto destShape = packOp.getDestType().getShape();
1143 if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
1144 return destShape[index] != 1;
1145 })) {
1146 return rewriter.notifyMatchFailure(
1147 packOp, "require the tiled outer dimensions of the result are all 1s");
1148 }
1149
1150 // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
1151 // outer dims.
1152 Location loc = packOp.getLoc();
1153 Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
1154 auto inputShape = packOp.getSourceType().getShape();
1155 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1156 packOp.getDimAndTileMapping();
1157 Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1158 Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1159 SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
1160 SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
1161 SmallVector<OpFoldResult> readSizes;
1162 SmallVector<int64_t> readShape;
1163 for (auto i : llvm::seq<unsigned>(0, srcRank)) {
1164 if (dimAndTileMapping.count(i)) {
1165 readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
1166 .value_or(ShapedType::kDynamic));
1167 readSizes.push_back(dimAndTileMapping[i]);
1168 continue;
1169 }
1170 if (ShapedType::isDynamic(inputShape[i])) {
1171 readSizes.push_back(
1172 rewriter.create<tensor::DimOp>(loc, input, i).getResult());
1173 } else {
1174 readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
1175 }
1176 if (inputShape[i] != 1)
1177 readShape.push_back(inputShape[i]);
1178 }
1179
1180 Type elemType = packOp.getSourceType().getElementType();
1181 auto readType = RankedTensorType::get(readShape, elemType);
1182
1183 Value tile = rewriter.create<tensor::ExtractSliceOp>(
1184 loc, readType, input, readOffsets, readSizes, readStrides);
1185
1186 // 2. Transpose the tile to match the inner tile order.
1187
1188 SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
1189 inputShape, innerDimsPos, packOp.getOuterDimsPerm());
1190
1191 LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
1192 llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
1193
1194 SmallVector<int64_t> transpShape = readShape;
1195 applyPermutationToVector<int64_t>(inVec&: transpShape, permutation: perm);
1196
1197 Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
1198 auto transposedOp =
1199 rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
1200
1201 // 3. Insert the inner tile to the destination.
1202 int64_t destRank = packOp.getDestRank();
1203 SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1204 SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1205 SmallVector<OpFoldResult> writeSizes =
1206 tensor::getMixedSizes(builder&: rewriter, loc, value: packOp.getDest());
1207
1208 auto insert = rewriter.create<tensor::InsertSliceOp>(
1209 loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
1210 writeSizes, writeStrides);
1211 rewriter.replaceOp(packOp, insert.getResult());
1212
1213 return success();
1214}
1215
1216LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
1217 tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
1218 int64_t srcRank = unpackOp.getSourceRank();
1219 int64_t destRank = unpackOp.getDestRank();
1220 ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
1221 ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1222 if (llvm::any_of(Range&: innerDimsPos, P: [srcShape](int64_t index) {
1223 return srcShape[index] != 1;
1224 })) {
1225 return rewriter.notifyMatchFailure(
1226 unpackOp,
1227 "require the tiled outer dimensions of the result are all 1s");
1228 }
1229
1230 // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
1231 Location loc = unpackOp.getLoc();
1232 Value source = unpackOp.getSource();
1233 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1234 unpackOp.getDimAndTileMapping();
1235 Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1236 Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1237 SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
1238 SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
1239 SmallVector<OpFoldResult> readSizes;
1240 SmallVector<int64_t> readShape;
1241 SmallVector<Value> dynamicDims;
1242 for (auto i : llvm::seq<unsigned>(0, destRank)) {
1243 if (dimAndTileMapping.count(i)) {
1244 readSizes.push_back(oneIdxAttr);
1245 continue;
1246 }
1247
1248 if (ShapedType::isDynamic(srcShape[i])) {
1249 Value dynamicDim =
1250 rewriter.create<tensor::DimOp>(loc, source, i).getResult();
1251 readSizes.push_back(dynamicDim);
1252 dynamicDims.push_back(dynamicDim);
1253 } else {
1254 readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
1255 }
1256 if (srcShape[i] != 1)
1257 readShape.push_back(srcShape[i]);
1258 }
1259 auto mixedTiles = unpackOp.getMixedTiles();
1260 readSizes.append(mixedTiles.begin(), mixedTiles.end());
1261
1262 // Explicitly create the type for extract_slice op because the inner tile
1263 // size could be 1. We want to represent the whole inner tile in this case.
1264 auto tileShape = srcShape.drop_front(N: destRank);
1265 // Append the inner tile shape to the permuted and rank-reduced outer shape.
1266 readShape.append(tileShape.begin(), tileShape.end());
1267 Type elemType = unpackOp.getSourceType().getElementType();
1268 auto readType = RankedTensorType::get(readShape, elemType);
1269 Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
1270 loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);
1271
1272 // 2. Transpose the tile to match the outer corresponding tile order.
1273 SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
1274 srcShape.take_front(N: destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1275 // Unpack is a transition out of packed space so we invert the permutation.
1276 perm = invertPermutationVector(permutation: perm);
1277 SmallVector<int64_t> transpShape(readShape);
1278 applyPermutationToVector<int64_t>(inVec&: transpShape, permutation: perm);
1279
1280 Value empty =
1281 rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
1282 auto transposedOp =
1283 rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
1284
1285 // 3. Handle in-complete tiles if needed. It truncates trailing data from the
1286 // transposed tile.
1287 int numLoops = transpShape.size();
1288 SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
1289 SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
1290 SmallVector<OpFoldResult> tileSizes;
1291 ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
1292 for (auto i : llvm::seq<unsigned>(0, destRank)) {
1293 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1294 tileSizes.push_back(
1295 tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
1296 }
1297
1298 auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
1299 loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
1300
1301 // 4. Insert the result to the destination tensor.
1302 SmallVector<OpFoldResult> writeSizes;
1303 SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1304 SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1305 for (int i = 0, idx = 0; i < destRank; ++i) {
1306 if (dimAndTileMapping.count(Val: i) || destShape[i] != 1)
1307 writeSizes.push_back(Elt: tileSizes[idx++]);
1308 else
1309 writeSizes.push_back(Elt: oneIdxAttr);
1310 }
1311 auto insert = rewriter.create<tensor::InsertSliceOp>(
1312 loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
1313 writeStrides);
1314 rewriter.replaceOp(unpackOp, insert.getResult());
1315
1316 return success();
1317}
1318
1319// The following are patterns for downscaling convolution ops with size-1
1320// window dimensions.
1321//
1322// Note that we'd eventually want to write such transformations in a generic
1323// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
1324// and then turning back to named ops. But for now it's fine to have a few
1325// patterns matching special ops to get started.
1326
1327template <typename Conv2DOp, typename Conv1DOp>
1328FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
1329 returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
1330 if (convOp.hasPureBufferSemantics())
1331 return failure(); // To be implemented.
1332
1333 Value input = convOp.getInputs().front();
1334 Value kernel = convOp.getInputs().back();
1335 Value output = convOp.getOutputs().front();
1336
1337 auto inputType = dyn_cast<RankedTensorType>(input.getType());
1338 auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1339 auto outputType = dyn_cast<RankedTensorType>(output.getType());
1340
1341 auto kernelShape = kernelType.getShape();
1342 auto outputShape = outputType.getShape();
1343
1344 // Get domain indices based on conv2D layout.
1345 auto [khIndex, kwIndex, ohIndex, owIndex] =
1346 TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
1347 convOp)
1348 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
1349 return std::make_tuple(args: 0, args: 1, args: 1, args: 2);
1350 })
1351 .Case([&](linalg::Conv2DNchwFchwOp op) {
1352 return std::make_tuple(args: 2, args: 3, args: 2, args: 3);
1353 })
1354 .Case([&](linalg::PoolingNhwcSumOp op) {
1355 return std::make_tuple(args: 0, args: 1, args: 1, args: 2);
1356 })
1357 .Case([&](linalg::PoolingNchwSumOp op) {
1358 return std::make_tuple(args: 0, args: 1, args: 2, args: 3);
1359 })
1360 .Case([&](linalg::PoolingNhwcMaxOp op) {
1361 return std::make_tuple(args: 0, args: 1, args: 1, args: 2);
1362 })
1363 .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
1364 return std::make_tuple(args: 0, args: 1, args: 1, args: 2);
1365 })
1366 .Case([&](linalg::PoolingNhwcMinOp op) {
1367 return std::make_tuple(args: 0, args: 1, args: 1, args: 2);
1368 })
1369 .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
1370 return std::make_tuple(args: 0, args: 1, args: 1, args: 2);
1371 })
1372 .Case([&](linalg::PoolingNchwMaxOp op) {
1373 return std::make_tuple(args: 0, args: 1, args: 2, args: 3);
1374 })
1375 .Default([&](Operation *op) {
1376 llvm_unreachable("unexpected conv2d/pool2d operation.");
1377 return std::make_tuple(args: 0, args: 0, args: 0, args: 0);
1378 });
1379
1380 // Only handle the case where at least one of the window dimensions is
1381 // of size 1. Other cases can rely on tiling to reduce to such cases.
1382 int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1383 int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1384 bool removeH = (khSize == 1 && ohSize == 1);
1385 bool removeW = (kwSize == 1 && owSize == 1);
1386 if (!removeH && !removeW)
1387 return failure();
1388
1389 // Get new shapes and types for all operands by removing the size-1
1390 // dimension.
1391 using RTTBuilder = RankedTensorType::Builder;
1392 RankedTensorType newInputType =
1393 RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
1394 RankedTensorType newKernelType =
1395 RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
1396 RankedTensorType newOutputType =
1397 RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
1398
1399 // Rank-reduce operands.
1400 Location loc = convOp.getLoc();
1401 Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1402 b&: rewriter, loc, tensor: input, targetType: newInputType);
1403 Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1404 b&: rewriter, loc, tensor: kernel, targetType: newKernelType);
1405 Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1406 b&: rewriter, loc, tensor: output, targetType: newOutputType);
1407
1408 // Rank-reduce strides and dilations too.
1409 // TODO: dropDim 1-liner helper.
1410 auto strides =
1411 llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
1412 strides.erase(strides.begin() + (removeH ? 0 : 1));
1413 auto stridesAttr = rewriter.getI64VectorAttr(values: strides);
1414
1415 auto dilations =
1416 llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
1417 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1418 auto dilationsAttr = rewriter.getI64VectorAttr(values: dilations);
1419
1420 auto conv1DOp = rewriter.create<Conv1DOp>(
1421 loc, newOutputType, ValueRange{newInput, newKernel},
1422 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1423
1424 // Insert back.
1425 Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1426 b&: rewriter, loc, tensor: conv1DOp.getResult(0), dest: output);
1427 rewriter.replaceOp(convOp, inserted);
1428
1429 return conv1DOp;
1430}
1431
1432template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
1433 Conv1DNwcWcfOp>;
1434template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
1435 Conv1DNcwFcwOp>;
1436template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
1437 PoolingNwcSumOp>;
1438template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
1439 PoolingNcwSumOp>;
1440template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
1441 PoolingNwcMaxOp>;
1442template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1443 PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1444template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
1445 PoolingNwcMinOp>;
1446template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1447 PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
1448template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
1449 PoolingNcwMaxOp>;
1450
1451FailureOr<DepthwiseConv1DNwcWcOp>
1452DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
1453 DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
1454 if (convOp.hasPureBufferSemantics())
1455 return failure(); // To be implemented.
1456
1457 Value input = convOp.getInputs().front();
1458 Value kernel = convOp.getInputs().back();
1459 Value output = convOp.getOutputs().front();
1460
1461 auto inputType = dyn_cast<RankedTensorType>(input.getType());
1462 auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1463 auto outputType = dyn_cast<RankedTensorType>(output.getType());
1464
1465 auto kernelShape = kernelType.getShape();
1466 auto outputShape = outputType.getShape();
1467
1468 // Only handle the case where at least one of the window dimensions is
1469 // of size 1. Other cases can rely on tiling to reduce to such cases.
1470 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1471 int64_t ohSize = outputShape[1], owSize = outputShape[2];
1472 bool removeH = (khSize == 1 && ohSize == 1);
1473 bool removeW = (kwSize == 1 && owSize == 1);
1474 if (!removeH && !removeW)
1475 return failure();
1476
1477 // Get new shapes and types for all operands by removing the size-1
1478 // dimension.
1479 using RTTBuilder = RankedTensorType::Builder;
1480 RankedTensorType newInputType =
1481 RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1482 RankedTensorType newKernelType =
1483 RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1484 RankedTensorType newOutputType =
1485 RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1486
1487 // Rank-reduce operands.
1488 Location loc = convOp.getLoc();
1489 Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1490 b&: rewriter, loc, tensor: input, targetType: newInputType);
1491 Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1492 b&: rewriter, loc, tensor: kernel, targetType: newKernelType);
1493 Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1494 b&: rewriter, loc, tensor: output, targetType: newOutputType);
1495
1496 // Rank-reduce strides and dilations too.
1497 // TODO: dropDim 1-liner helper.
1498 auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
1499 strides.erase(strides.begin() + (removeH ? 0 : 1));
1500 auto stridesAttr = rewriter.getI64VectorAttr(values: strides);
1501
1502 auto dilations =
1503 llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
1504 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1505 auto dilationsAttr = rewriter.getI64VectorAttr(values: dilations);
1506
1507 auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
1508 loc, newOutputType, ValueRange{newInput, newKernel},
1509 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1510
1511 // Insert back.
1512 Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1513 b&: rewriter, loc, tensor: conv1DOp.getResult(0), dest: output);
1514 rewriter.replaceOp(convOp, inserted);
1515
1516 return conv1DOp;
1517}
1518
1519FailureOr<Conv1DOp>
1520DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
1521 PatternRewriter &rewriter) const {
1522 if (convOp.hasPureBufferSemantics())
1523 return failure(); // To be implemented.
1524
1525 Value input = convOp.getInputs().front();
1526 Value kernel = convOp.getInputs().back();
1527 Value output = convOp.getOutputs().front();
1528
1529 auto inputType = dyn_cast<RankedTensorType>(input.getType());
1530 auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1531 auto outputType = dyn_cast<RankedTensorType>(output.getType());
1532
1533 auto kernelShape = kernelType.getShape();
1534 auto outputShape = outputType.getShape();
1535
1536 // Only handle the case where at least one of the window dimensions is
1537 // of size 1. Other cases can rely on tiling to reduce to such cases.
1538 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1539 int64_t ohSize = outputShape[0], owSize = outputShape[1];
1540 bool removeH = (khSize == 1 && ohSize == 1);
1541 bool removeW = (kwSize == 1 && owSize == 1);
1542 if (!removeH && !removeW)
1543 return failure();
1544
1545 // Get new shapes and types for all operands by removing the size-1
1546 // dimension.
1547 using RTTBuilder = RankedTensorType::Builder;
1548 RankedTensorType newInputType =
1549 RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
1550 RankedTensorType newKernelType =
1551 RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1552 RankedTensorType newOutputType =
1553 RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
1554
1555 // Rank-reduce operands.
1556 Location loc = convOp.getLoc();
1557 Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1558 b&: rewriter, loc, tensor: input, targetType: newInputType);
1559 Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1560 b&: rewriter, loc, tensor: kernel, targetType: newKernelType);
1561 Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1562 b&: rewriter, loc, tensor: output, targetType: newOutputType);
1563
1564 auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
1565 ValueRange{newInput, newKernel},
1566 ValueRange{newOutput});
1567
1568 // Insert back.
1569 Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1570 b&: rewriter, loc, tensor: conv1DOp.getResult(0), dest: output);
1571 rewriter.replaceOp(convOp, inserted);
1572
1573 return conv1DOp;
1574}
1575
1576void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
1577 PatternBenefit benefit) {
1578 patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
1579 Conv1DNwcWcfOp>,
1580 DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
1581 Conv1DNcwFcwOp>,
1582 DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
1583 patterns.getContext(), benefit);
1584 patterns.add<
1585 DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
1586 DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
1587 DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
1588 DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
1589 PoolingNwcMaxUnsignedOp>,
1590 DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
1591 DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
1592 PoolingNwcMinUnsignedOp>,
1593 DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
1594 patterns.getContext(), benefit);
1595}
1596

source code of mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp