1//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements logic and helpers to expose Linalg transforms as rewrite
10// patterns.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/Linalg/IR/Linalg.h"
19#include "mlir/Dialect/Linalg/Utils/Utils.h"
20#include "mlir/Dialect/SCF/Transforms/Transforms.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
23#include "mlir/Dialect/Tensor/Utils/Utils.h"
24#include "mlir/Dialect/Utils/IndexingUtils.h"
25#include "mlir/Dialect/Utils/StaticValueUtils.h"
26#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
27#include "mlir/Dialect/Vector/IR/VectorOps.h"
28#include "mlir/IR/AffineExpr.h"
29#include "mlir/IR/Matchers.h"
30#include "mlir/Pass/Pass.h"
31#include "mlir/Support/LLVM.h"
32#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
33#include "llvm/ADT/ScopeExit.h"
34#include "llvm/ADT/TypeSwitch.h"
35#include "llvm/Support/Debug.h"
36#include "llvm/Support/InterleavedRange.h"
37#include "llvm/Support/raw_ostream.h"
38#include <type_traits>
39#include <utility>
40
41#define DEBUG_TYPE "linalg-transforms"
42
43using namespace mlir;
44using namespace mlir::linalg;
45
46#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
47#define DBGSNL() (llvm::dbgs() << "\n")
48
49//===----------------------------------------------------------------------===//
50// Transformations exposed as functional-style API calls.
51//===----------------------------------------------------------------------===//
52
53//===----------------------------------------------------------------------===//
54// peelLoop transformation.
55//===----------------------------------------------------------------------===//
56
57/// Try to peel and canonicalize loop `op` and return the new result.
58/// Also applies affine_min/max bounds simplification on the fly where relevant.
59// TODO: Add support for scf.parallel and affine.for loops.
60SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
61 Operation *op) {
62 return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
64 scf::ForOp partialIteration;
65 if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
66 partialIteration)))
67 return partialIteration->getResults();
68 assert(!partialIteration && "expected that loop was not peeled");
69 return forOp->getResults();
70 })
      .Default([&](Operation *op) { return op->getResults(); });
72}
73
/// Peel `loops`, applying affine_min/max bounds simplification on the fly
/// where relevant.
76void mlir::linalg::peelLoops(RewriterBase &rewriter,
77 ArrayRef<scf::ForOp> loops) {
78 for (auto loopOp : loops)
79 peelLoop(rewriter, loopOp);
80}
81
82//===----------------------------------------------------------------------===//
83// pack transformation.
84//===----------------------------------------------------------------------===//
85
86#ifndef NDEBUG
/// Return true if `map` has at most one result that is a function of
/// AffineDimExpr(dim).
88static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
89 bool found = false;
90 for (AffineExpr e : map.getResults()) {
    if (!e.isFunctionOfDim(dim))
92 continue;
93 if (found)
94 return false;
95 found = true;
96 }
97 return true;
98}
99
100static std::string stringifyReassocIndices(ReassociationIndicesRef ri) {
  return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/"");
102}
103#endif // NDEBUG
104
/// Return the index of the first result of `map` that is a function of
/// AffineDimExpr(dim), or std::nullopt if there is none.
107static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
108 int64_t dim) {
109 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    AffineExpr expr = map.getResult(i);
    if (!expr.isFunctionOfDim(dim))
112 continue;
113 return i;
114 }
115 return std::nullopt;
116}
117
118/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
119/// `newDim` at `iteratorTypes.size()` by:
120/// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
121/// 2. Appending a `newDim` to the domain of every indexing map.
122/// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing
123/// by potentially adding a `newDim` result to `map`.
124/// The preserved invariant is that `iteratorTypes.size()` is always equal to
125/// `map.getNumDims()` for every map in `indexingMaps`.
126///
127/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
128/// Return a vector that records the optional packing for each operand.
129/// Return failure if the packed indexing cannot be represented with a LinalgOp.
130///
131/// Further details:
132/// ================
133/// The current implementation of packing (i.e. data tiling) consists of
134/// rewriting a linearized strip-mined form into a higher-dimensional access.
135/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
136/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
137/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
138///
/// This rewrite into a higher-dimensional access is currently not possible for
/// a general AffineExpr in Linalg; it is restricted to an AffineDimExpr:
141/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
142/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
143/// The rewrite of the access would be a form not representable in Linalg:
144/// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
145/// Note however that as `J` and `ii` iterate, the accesses do not have a
/// particular alignment, so packing does not achieve alignment in this case.
147///
148/// In the future, we may want to consider a mixed-form that allows some
149/// alignment in the presence of multiple accesses:
150/// `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
151/// And would rewrite accesses as:
152/// `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
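///
/// Example (illustrative sketch): assume a plain matmul with indexing maps
///   LHS: `(d0, d1, d2) -> (d0, d2)`
///   RHS: `(d0, d1, d2) -> (d2, d1)`
///   RES: `(d0, d1, d2) -> (d0, d1)`
/// and iterators `[par, par, red]`. Packing `dim = 2` (i.e. `k`) appends a new
/// dim `d3` to every domain and a `d3` result to every map that uses `d2`:
///   LHS: `(d0, d1, d2, d3) -> (d0, d2, d3)`
///   RHS: `(d0, d1, d2, d3) -> (d2, d1, d3)`
///   RES: `(d0, d1, d2, d3) -> (d0, d1)`
/// `iteratorTypes` becomes `[par, par, red, red]` and the returned vector is
/// `{1, 0, std::nullopt}`, i.e. the packed result position for each operand.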
153static FailureOr<SmallVector<std::optional<int64_t>>>
154packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
155 SmallVectorImpl<utils::IteratorType> &iteratorTypes,
156 int64_t dim) {
157 int64_t newDim = iteratorTypes.size();
158 iteratorTypes.push_back(iteratorTypes[dim]);
159
160 SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
161 indexingMaps.size(), std::nullopt);
162 SmallVector<AffineMap> newMaps;
163 for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
164 ++operandIdx) {
165 AffineMap map = indexingMaps[operandIdx];
166
    // Add the `newDim` to the map's domain in all cases.
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    map = map.shiftDims(1, newDim);
170
171 // Get the at-most-1 index of the result that is a function of `dim`.
172 // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
173 // logically chunks dimension `dim` into `K * dim + newDim`, where the
174 // packing factor `K` is specified separately.
175 assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
176 "num results invariant violation");
177 auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
178 if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);
180 continue;
181 }
182
183 // We can only pack AffineDimExpr atm.
    if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
185 return failure();
186
187 // Add `newDim` to the results of the map.
    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
                           map.getNumResults());
    newMaps.push_back(map);
191
    // Record that `operandIdx` is packed.
193 packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
194 }
195 indexingMaps = newMaps;
196
197 return packedDimPerIndexingMap;
198}
199
200namespace {
201
202/// Helper struct to encode packing along one dimension of a LinalgOp.
203struct PackedOperandsDim {
204 OpFoldResult packedSize;
205 SmallVector<std::optional<int64_t>> packedDimForEachOperand;
206};
207
208/// Helper struct to encode packing along all dimensions of a LinalgOp.
209struct PackedOperandsDimList {
210 void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
212 }
213 /// Return all the dims that have been packed for operand @ `operandPos`.
214 SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
215 /// Return all the pack sizes by which an operand @ `operandPos` is packed.
216 SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
217
218private:
219 SmallVector<PackedOperandsDim> spec;
220};
221
222} // namespace
223
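/// lowerPack rewrites a linalg.pack into tensor.pad + tensor.expand_shape +
/// linalg.transpose (or a single tensor.insert_slice for pad-like packs when
/// `lowerPadLikeWithInsertSlice` is set). Schematic example, illustrative
/// only, with types, attributes and the pad body elided:
///   %0 = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
///          into %dest : tensor<?x?xf32> -> tensor<?x?x8x16xf32>
/// is rewritten to roughly:
///   %padded = tensor.pad %src low[0, 0] high[%h0, %h1] { ... }
///   %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]]
///   %0 = linalg.transpose ins(%expanded) outs(%dest)
///          permutation = [0, 2, 1, 3]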
224FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
225 linalg::PackOp packOp,
226 bool lowerPadLikeWithInsertSlice) {
227 // 1. Filter out NYI cases.
228 auto packedTensorType =
229 cast<RankedTensorType>(packOp->getResultTypes().front());
230 if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
231 return rewriter.notifyMatchFailure(
232 packOp,
233 "non-static shape NYI, needs a more powerful tensor.expand_shape op");
234 }
235
236 Location loc = packOp->getLoc();
237 OpBuilder::InsertionGuard g(rewriter);
238 rewriter.setInsertionPoint(packOp);
239
240 // 2. Compute the permutation vector to shuffle packed shape into the shape
241 // before any outer or inner permutations have been applied.
242 PackingMetadata packingMetadata = computePackingMetadata(
243 packedTensorType.getRank(), packOp.getInnerDimsPos());
244 SmallVector<int64_t> packedToStripMinedShapePerm =
245 getPackInverseDestPerm(packOp);
246
247 // 3. Compute the stripMinedShape: this is the packed shape before any outer
248 // or inner permutations have been applied.
249 SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
251
252 // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
253 SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
254 rewriter.getIndexAttr(0));
255 SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
256 rewriter.getIndexAttr(0));
257 for (auto [pos, innerSize] :
258 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
259 int outerPos =
260 packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
261 OpFoldResult origSize =
262 tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
263 OpFoldResult outerSize =
264 tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
265 AffineExpr s0, d0, d1;
266 bindDims(rewriter.getContext(), d0, d1);
267 bindSymbols(rewriter.getContext(), s0);
268 auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
269 highs[pos] = affine::makeComposedFoldedAffineApply(
270 rewriter, loc, map, {outerSize, origSize, innerSize});
271 }
272 RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
273 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
274 packingMetadata.reassociations);
275 Value paddingValue = packOp.getPaddingValue();
276 if (!paddingValue) {
277 paddingValue = rewriter.create<arith::ConstantOp>(
278 loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
279 }
280 auto padOp =
281 rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
282 highs, paddingValue, /*nofold=*/false);
283
284 LLVM_DEBUG(
285 DBGSNL(); DBGSNL();
286 DBGS() << "insertPositions: "
287 << llvm::interleaved(packingMetadata.insertPositions);
288 DBGSNL(); DBGS() << "outerPositions: "
289 << llvm::interleaved(packingMetadata.outerPositions);
290 DBGSNL(); DBGS() << "packedShape: "
291 << llvm::interleaved(packedTensorType.getShape());
292 DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
293 << llvm::interleaved(packedToStripMinedShapePerm);
294 DBGSNL();
295 DBGS() << "reassociations: "
296 << llvm::interleaved(llvm::map_range(
297 packingMetadata.reassociations, stringifyReassocIndices));
298 DBGSNL();
299 DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
300 DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
301
302 if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
303 // Pack ops which operate as simple pads may not produce legal
304 // tensor.insert_slice operations when the packed type does not rank reduce
305 // to the padded type.
306 SliceVerificationResult rankReduces =
307 isRankReducedType(packedTensorType, padOp.getResultType());
308
309 if (rankReduces == SliceVerificationResult::Success) {
310 // This pack is just a plain pad.
311 // Just insert the pad in the higher ranked tensor.
312 // Offsets.
313 SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
314 rewriter.getIndexAttr(0));
315 // Strides.
316 SmallVector<OpFoldResult> ones(packOp.getDestRank(),
317 rewriter.getIndexAttr(1));
318 SmallVector<OpFoldResult> sizes =
          tensor::getMixedSizes(rewriter, loc, packOp.getDest());
320
321 auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
322 loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
323 /*offsets=*/zeros, sizes, /*strides=*/ones);
324
325 LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
326
327 rewriter.replaceOp(packOp, insertSliceOp->getResults());
328
329 return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
330 /*transposeOp=*/nullptr};
331 }
332 }
333
334 // 5. Expand from the padded result to the stripMinedShape.
335 auto expandShapeResultType =
336 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
337 auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
338 loc, expandShapeResultType, padOp.getResult(),
339 packingMetadata.reassociations);
340
341 // 6. Transpose stripMinedShape to packedShape.
342 SmallVector<int64_t> transpPerm =
      invertPermutationVector(packedToStripMinedShapePerm);
344 auto transposeOp = rewriter.create<linalg::TransposeOp>(
345 loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
346
347 LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
348 DBGS() << "reshape op: " << reshapeOp; DBGSNL();
349 DBGS() << "transpPerm: " << llvm::interleaved(transpPerm);
350 DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
351
352 // 7. Replace packOp by transposeOp.
353 rewriter.replaceOp(packOp, transposeOp->getResults());
354
355 return LowerPackResult{padOp, reshapeOp, transposeOp};
356}
357
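/// lowerUnPack rewrites a linalg.unpack into linalg.transpose +
/// tensor.collapse_shape + tensor.extract_slice + linalg.copy (or a single
/// tensor.extract_slice for unpad-like unpacks when
/// `lowerUnpadLikeWithExtractSlice` is set). Schematic example, illustrative
/// only, with types and dynamic sizes elided:
///   %0 = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
///          into %dest : tensor<?x?x8x16xf32> -> tensor<?x?xf32>
/// is rewritten to roughly:
///   %empty = tensor.empty(...)
///   %t = linalg.transpose ins(%src) outs(%empty) permutation = [0, 2, 1, 3]
///   %c = tensor.collapse_shape %t [[0, 1], [2, 3]]
///   %s = tensor.extract_slice %c[0, 0] [%d0, %d1] [1, 1]
///   %0 = linalg.copy ins(%s) outs(%dest)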
358FailureOr<LowerUnPackOpResult>
359linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
360 bool lowerUnpadLikeWithExtractSlice) {
361 Location loc = unPackOp->getLoc();
362 OpBuilder::InsertionGuard g(rewriter);
363 rewriter.setInsertionPoint(unPackOp);
364
365 RankedTensorType packedTensorType = unPackOp.getSourceType();
366 int64_t packedRank = packedTensorType.getRank();
367
368 OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
369 auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
370 if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
371 // This unpack is just a plain unpad.
372 // Just extract the slice from the higher ranked tensor.
373 ArrayRef<int64_t> destShape = destTensorType.getShape();
374 // The inner dimensions stay the same as the destination tensor, but the
375 // outer ones are additional 1s.
376 SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
    sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));
378
379 auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
380 loc, destTensorType, unPackOp.getSource(),
381 SmallVector<OpFoldResult>(packedRank, zero), sizes,
382 SmallVector<OpFoldResult>(packedRank, one));
383
384 rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
385
386 return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
387 /*reshapeOp=*/nullptr, extractSliceOp};
388 }
389
390 // 1. Compute the permutation vector to shuffle packed shape into the shape
391 // before any outer or inner permutations have been applied.
392 PackingMetadata packingMetadata;
393 SmallVector<int64_t> packedToStripMinedShapePerm =
394 getUnPackInverseSrcPerm(unPackOp, packingMetadata);
395
396 // 2. Compute the stripMinedShape: this is the packed shape without outer and
397 // inner permutations.
398 SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
400
401 // 3. Transpose packedShape to stripMinedShape.
402 RankedTensorType stripMinedTensorType =
403 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
404 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
405 stripMinedTensorType, packingMetadata.reassociations);
406
407 // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
408 // permutation.
409 SmallVector<OpFoldResult, 4> dims =
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, packedToStripMinedShapePerm);
412 auto emptyOp = rewriter.create<tensor::EmptyOp>(
413 loc, dims, stripMinedTensorType.getElementType());
414 auto transposeOp = rewriter.create<linalg::TransposeOp>(
415 loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
416
417 LLVM_DEBUG(
418 DBGSNL(); DBGSNL();
419 DBGS() << "insertPositions: "
420 << llvm::interleaved(packingMetadata.insertPositions);
421 DBGSNL(); DBGS() << "packedShape: "
422 << llvm::interleaved(packedTensorType.getShape());
423 DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
424 << llvm::interleaved(packedToStripMinedShapePerm);
425 DBGSNL();
426 DBGS() << "reassociations: "
427 << llvm::interleaved(llvm::map_range(
428 packingMetadata.reassociations, stringifyReassocIndices));
429 DBGSNL();
430 DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
431 DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
432
433 // 4. Collapse from the stripMinedShape to the padded result.
434 auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
435 loc, collapsedType, transposeOp->getResult(0),
436 packingMetadata.reassociations);
437
438 // 5. ExtractSlice.
439 int64_t destRank = destTensorType.getRank();
440 auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
441 loc, destTensorType, reshapeOp->getResult(0),
442 SmallVector<OpFoldResult>(destRank, zero),
      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
444 SmallVector<OpFoldResult>(destRank, one));
445
446 // 6. Inject a copy to preserve DPS.
447 auto copyOp = rewriter.create<linalg::CopyOp>(
448 loc, extractSliceOp->getResult(0), unPackOp.getDest());
449
450 // 7. Replace unPackOp by copyOp.
451 rewriter.replaceOp(unPackOp, copyOp->getResults());
452
453 return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
454}
455
456SmallVector<int64_t>
457PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
458 SmallVector<int64_t> res;
459 for (auto &i : spec) {
460 if (!i.packedDimForEachOperand[operandPos].has_value())
461 continue;
    res.push_back(i.packedDimForEachOperand[operandPos].value());
463 }
464 return res;
465}
466
467SmallVector<OpFoldResult>
468PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
469 SmallVector<OpFoldResult> res;
470 for (auto &i : spec) {
471 if (!i.packedDimForEachOperand[operandPos].has_value())
472 continue;
    res.push_back(i.packedSize);
474 }
475 return res;
476}
477
/// Pack a single LinalgOp by `packedSizes`. There must be one packedSizes
/// entry per `linalgOp` iterator. Return the packed Linalg op on success,
/// failure otherwise.
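///
/// Example (illustrative sketch): for a matmul packed with
/// `packedSizes = [8, 16, 32]`, i.e. m, n and k packed by 8, 16 and 32, this
/// produces roughly:
///   %lhs_p = linalg.pack %lhs inner_dims_pos = [0, 1] inner_tiles = [8, 32] ...
///   %rhs_p = linalg.pack %rhs inner_dims_pos = [0, 1] inner_tiles = [32, 16] ...
///   %out_p = linalg.pack %out inner_dims_pos = [0, 1] inner_tiles = [8, 16] ...
///   %res_p = linalg.generic ... ins(%lhs_p, %rhs_p) outs(%out_p)
///   %res   = linalg.unpack %res_p inner_dims_pos = [0, 1]
///              inner_tiles = [8, 16] into %out ...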
481FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
482 linalg::LinalgOp linalgOp,
483 ArrayRef<OpFoldResult> packedSizes) {
484 if (packedSizes.size() != linalgOp.getNumLoops()) {
485 return rewriter.notifyMatchFailure(linalgOp,
486 "incorrect number of pack sizes");
487 }
488
489 Location loc = linalgOp->getLoc();
490 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
491 SmallVector<utils::IteratorType> iteratorTypes =
492 linalgOp.getIteratorTypesArray();
493 LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n"
494 << "maps: " << llvm::interleaved(indexingMaps) << "\n"
495 << "iterators: " << llvm::interleaved(iteratorTypes)
496 << "\n");
497
498 SmallVector<linalg::PackOp> packOps;
499 SmallVector<linalg::UnPackOp> unPackOps;
500 // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
501 PackedOperandsDimList listOfPackedOperandsDim;
502 for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
504 // Skip tile sizes explicitly set to 0.
505 if (maybeConstant.has_value() && maybeConstant.value() == 0)
506 continue;
507
508 PackedOperandsDim packedOperandsDims;
509 packedOperandsDims.packedSize = packedSizes[i];
510 FailureOr<SmallVector<std::optional<int64_t>>>
511 maybePackedDimForEachOperand =
512 packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
    if (failed(maybePackedDimForEachOperand))
514 return failure();
515 packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
517
518 LLVM_DEBUG(
519 DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
520 << "\n"
521 << "maps: " << llvm::interleaved(indexingMaps) << "\n"
522 << "iterators: " << llvm::interleaved(iteratorTypes) << "\n"
523 << "packedDimForEachOperand: "
524 << llvm::interleaved(packedOperandsDims.packedDimForEachOperand)
525 << "\n");
526 }
527
528 // Step 2. Propagate packing to all LinalgOp operands.
529 SmallVector<Value> inputsAndInits, results;
530 SmallVector<OpOperand *> initOperands =
531 llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
532 SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
533 for (const auto &operandsList : {inputOperands, initOperands}) {
534 for (OpOperand *opOperand : operandsList) {
535 int64_t pos = opOperand->getOperandNumber();
536 Value operand = opOperand->get();
537 SmallVector<int64_t> innerPos =
538 listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
539 SmallVector<OpFoldResult> innerPackSizes =
540 listOfPackedOperandsDim.extractPackSizesForOperand(pos);
541 LLVM_DEBUG(DBGS() << "operand: " << operand << "\n"
542 << "innerPos: " << llvm::interleaved(innerPos) << "\n"
543 << "innerPackSizes: "
544 << llvm::interleaved(innerPackSizes) << "\n");
545 if (innerPackSizes.empty()) {
546 inputsAndInits.push_back(operand);
547 continue;
548 }
549 Value dest = linalg::PackOp::createDestinationTensor(
550 rewriter, loc, operand, innerPackSizes, innerPos,
551 /*outerDimsPerm=*/{});
552 ShapedType operandType = cast<ShapedType>(operand.getType());
553 bool areConstantTiles =
554 llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
555 return getConstantIntValue(tile).has_value();
556 });
557 if (areConstantTiles && operandType.hasStaticShape() &&
558 !linalg::PackOp::requirePaddingValue(
559 operandType.getShape(), innerPos,
560 cast<ShapedType>(dest.getType()).getShape(), {},
561 innerPackSizes)) {
562 packOps.push_back(rewriter.create<linalg::PackOp>(
563 loc, operand, dest, innerPos, innerPackSizes));
564 } else {
565 // TODO: value of the padding attribute should be determined by
566 // consumers.
567 auto zeroAttr =
568 rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
569 Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
570 packOps.push_back(rewriter.create<linalg::PackOp>(
571 loc, operand, dest, innerPos, innerPackSizes, zero));
572 }
573 inputsAndInits.push_back(packOps.back());
574 }
575 }
576
577 // Step 3. Build the packed op, use the type of `inits` as result types.
578 ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
580 ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
582 auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
583 linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
584 iteratorTypes);
585 packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
586
587 // Step 4. Propagate packing to all the op results.
588 for (OpResult result : packedLinalgOp->getResults()) {
589 int64_t resultNum = result.getResultNumber();
590 linalg::PackOp maybePackedInit =
591 inits[resultNum].getDefiningOp<linalg::PackOp>();
592 if (!maybePackedInit) {
593 results.push_back(result);
594 continue;
595 }
596 // Build the symmetrical UnPackOp to the existing PackOp.
597 unPackOps.push_back(rewriter.create<linalg::UnPackOp>(
598 packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
599 maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
600 results.push_back(unPackOps.back());
601 }
602
603 // Step 5. Replace `linalgOp`.
604 rewriter.replaceOp(linalgOp, results);
605
606 // Return packedLinalgOp.
607 return PackResult{packOps,
608 cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
609 unPackOps};
610}
611
612//===----------------------------------------------------------------------===//
613// packTranspose transformation.
614//===----------------------------------------------------------------------===//
615
616/// Return a copy of `tensorType` after permutation by `permutationVector`.
// Note: Should be a new method in MemRef/RankedTensor/VectorType::Builder
618// but this would introduce a dependence on Dialect in IR.
619// TODO: Restructure.
620static RankedTensorType permuteShape(RankedTensorType tensorType,
621 ArrayRef<int64_t> permutationVector) {
622 SmallVector<int64_t> shape(tensorType.getShape());
  applyPermutationToVector(shape, permutationVector);
624 return RankedTensorType::Builder(tensorType).setShape(shape);
625}
626
627/// Return a new GenericOp obtained by transposing opOperand by the permutation
628/// vector:
629/// - the corresponding indexing map is transposed by `permutation`
630/// - the corresponding operand value is replaced by `transposedValue`
631/// `linalgOp` is replaced by the return op in the process.
632/// Asserts that `transposedValue` is of the proper transposed ShapedType.
633static LinalgOp transposeOneLinalgOperandAndReplace(
634 RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
635 ArrayRef<int64_t> permutation, Value transposedValue) {
636 // Sanity check the operand.
637 assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
638
639 // Sanity check of the expected transposed tensor type.
640 auto tensorType = permuteShape(
641 cast<RankedTensorType>(opOperand.get().getType()), permutation);
642 (void)tensorType;
  assert(tensorType == transposedValue.getType() &&
         "expected matching tensor types");
645
646 // Compute the transposed indexing map.
647 // Sigh unsigned pollution.
648 SmallVector<unsigned> tmpTransposition = llvm::to_vector(
      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
650 AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
652 AffineMap transposedMap =
653 permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
654
655 // Set the transposed indexing map in the proper position.
656 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
657 indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
658 // Set the transposedValue in the proper operand position.
659 SmallVector<Value> operands = linalgOp->getOperands();
660 operands[opOperand.getOperandNumber()] = transposedValue;
661
662 ValueRange operandsRef(operands);
663 auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
664 /*location=*/linalgOp->getLoc(),
665 /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
669 /*indexingMaps=*/indexingMaps,
670 /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
671 transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
672 rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
673
674 return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
675}
676
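/// packTranspose transposes a linalg.pack + compute op (+ optional
/// linalg.unpack) chain by `outerPerm` / `innerPerm`, composing the compute
/// op's indexing map on the packed operand with the same permutation so the
/// overall computation is unchanged. Schematic example, illustrative only,
/// with `innerPerm = [1, 0]` and an empty `outerPerm`:
///   %p = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16] ...
/// becomes
///   %p = linalg.pack %src inner_dims_pos = [1, 0] inner_tiles = [16, 8] ...
/// and the consuming indexing map swaps its two innermost results accordingly.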
677FailureOr<PackTransposeResult>
678linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
679 linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
680 ArrayRef<int64_t> outerPerm,
681 ArrayRef<int64_t> innerPerm) {
682 Location loc = linalgOp.getLoc();
683
684 // Step 1. Transpose packOp.
685 rewriter.setInsertionPoint(packOp);
686 linalg::PackOp transposedPackOp =
687 packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
688
689 if (!packOp.getResult().hasOneUse())
690 return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");
691
692 OpOperand &packUse = *packOp->getUses().begin();
693 if (packUse.getOwner() != linalgOp) {
694 return rewriter.notifyMatchFailure(
695 linalgOp, "not a single use by the LinalgOp target");
696 }
697 if (maybeUnPackOp &&
698 (!linalgOp.isDpsInit(&packUse) ||
699 maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
700 return rewriter.notifyMatchFailure(linalgOp,
701 "not produced by the LinalgOp target");
702 }
703
704 // Step 2. Transpose linalgOp.
705 // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
706 // identity. Don't rely on it.
707 int64_t numLeadingDims = packOp.getSourceRank();
708 int64_t numTrailingDims = packOp.getInnerDimsPos().size();
709 // Step 2.a. Compute the permutation on the whole operand.
  // The leading part just reuses the outerPerm.
  SmallVector<int64_t> permutation(outerPerm);
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  // The trailing part needs to reindex positions by `numLeadingDims`.
  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
719 } else {
720 llvm::append_range(permutation,
721 llvm::map_range(innerPerm, [&](int64_t pos) {
722 return numLeadingDims + pos;
723 }));
724 }
  if (!isPermutationVector(permutation))
726 return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");
727
728 // Step 2.b. Save the transposedPackUse operand number in case we need to
729 // get the tied OpResult after `linalgOp` has been replaced.
730 int64_t packUseOperandNumber = packUse.getOperandNumber();
731 // Step 2.c. Actually perform the transposition.
732 rewriter.setInsertionPoint(linalgOp);
733 linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
734 rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
735
736 // Step 3. Maybe transpose unPackOp.
737 linalg::UnPackOp transposedUnPackOp;
738 if (maybeUnPackOp) {
739 OpOperand &opOperand =
740 transposedLinalgOp->getOpOperand(packUseOperandNumber);
741 OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
742 rewriter.setInsertionPoint(maybeUnPackOp);
743 transposedUnPackOp = maybeUnPackOp.createTransposedClone(
744 rewriter, loc, transposedResult, innerPerm, outerPerm);
745
746 rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
747 }
748
749 // Step 4. Finally, replace packOp now that we don't need it anymore.
750 rewriter.replaceOp(packOp, transposedPackOp->getResults());
751
752 return PackTransposeResult{transposedPackOp, transposedLinalgOp,
753 transposedUnPackOp};
754}
755
756//===----------------------------------------------------------------------===//
757// packMatmulGreedily transformation.
758//===----------------------------------------------------------------------===//
759
760/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
761/// and n are proper parallel dimensions and k is a proper reduction
762/// dimension. Packing occurs by rewriting the op as a linalg.generic and
763/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
764/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
/// to reorder {m, n, k} into one of the 6 possible orders. The outer
766/// dimensions of the operands are not permuted at this time, this is left for
767/// future work.
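///
/// Minimal usage sketch (illustrative only; `rewriter`, `ctx` and `op` are
/// assumed to exist in the caller), packing m, n, k by 8, 16 and 32 with no
/// extra padding and the default {m, n, k} order:
///   FailureOr<PackResult> res = linalg::packMatmulGreedily(
///       rewriter, op,
///       /*mnkPackedSizes=*/getAsIndexOpFoldResult(ctx, {8, 16, 32}),
///       /*mnkPaddedSizesNextMultipleOf=*/{}, /*mnkOrder=*/{0, 1, 2});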
768FailureOr<PackResult>
769linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
770 ArrayRef<OpFoldResult> mnkPackedSizes,
771 ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
772 ArrayRef<int64_t> mnkOrder) {
773 assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
774 assert((mnkPaddedSizesNextMultipleOf.empty() ||
775 mnkPaddedSizesNextMultipleOf.size() == 3) &&
776 "num of packing sizes next multiple should be empty or of size 3");
777 assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
778 assert(isPermutationVector(mnkOrder) && "expected a permutation");
779
780 int64_t numLoops = linalgOp.getNumLoops();
781 if (numLoops <= 2) {
782 LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
783 << numLoops << "\nin: " << linalgOp << "\n");
784 return rewriter.notifyMatchFailure(
785 linalgOp, "need 3+ loops to find a matmul to pack");
786 }
787
788 // Locally adjust the desired iterator position of mnk and packing sizes.
789 int64_t numPackedDims = mnkPackedSizes.size();
790 SmallVector<int64_t> mmnnkkPos(numPackedDims);
791 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
792 mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
793 SmallVector<OpFoldResult> packedSizes(numPackedDims);
794 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
795 packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
796 SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
797 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
798 paddedSizesNextMultipleOf[mnkOrder[i]] =
799 mnkPaddedSizesNextMultipleOf.empty() ? 0
800 : mnkPaddedSizesNextMultipleOf[i];
801 }
802
803 // 1. Infer dims that are important for matmul.
804 FailureOr<ContractionDimensions> maybeDimensions =
805 inferContractionDims(linalgOp);
  if (failed(maybeDimensions)) {
807 LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
808 << "\n");
809 return rewriter.notifyMatchFailure(linalgOp,
810 "couldn't infer matmul iterators");
811 }
812
  // 2. Normalize linalgOp to a kmn-matmul-like op with [red, par, par] as the
  // most minor iterators. In cases with multiple options for m, n, k, bias
  // towards the most minor embedding.
  // If we wanted a different normalization order, this is where a heuristic
  // would have to plug in.
818 int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
819 kPos = maybeDimensions->k.back();
820 LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
821 DBGS() << "Start packing generic op greedily with (m@" << mPos
822 << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
823 << "\n";);
824
825 // 2.a. Rewrite as a generic.
826 auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
827 if (!genericOp) {
828 FailureOr<GenericOp> generalizeResult =
829 generalizeNamedOp(rewriter, linalgOp);
830 assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
831 genericOp = *generalizeResult;
832 }
833
834 // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
  // iterators. Note that this only normalizes the iteration order and does
836 // not change the indexings of any operand.
837 SmallVector<int64_t> permutation =
      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
839 LLVM_DEBUG(DBGS() << "perm: " << llvm::interleaved(permutation) << "\n");
840 // Sign .. unsigned pollution.
841 SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
842 FailureOr<GenericOp> interchangeResult =
843 interchangeGenericOp(rewriter, genericOp, unsignedPerm);
844 assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
845 genericOp = *interchangeResult;
846 LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
847
848 // At this point, the op iterators are normalized to {leading, k, m, n}.
849 // The layouts induced by packing will always be:
850 // - LHS{leading_lhs, kk, mm}
851 // - RHS{leading_rhs, kk, nn}
852 // - RES{leading_res, mm, nn}
853 // If we wanted to change the packed order, we would reorder (k, m, n) to
854 // something else above.
855 //
856 // Additional permutations of the outer dims of the operands (i.e.
857 // leading_lhs, leading_rhs and leading_res) could follow by computing the
858 // desired outerPerm for each operand.
859 // This is left for future work.
860
861 // TODO: this creates too much IR, go use reifyResultShapes.
862 SmallVector<Range, 4> loopRanges =
863 cast<LinalgOp>(genericOp.getOperation())
864 .createLoopRanges(rewriter, genericOp.getLoc());
865
  // Add leading zeros to match numLoops; we only pack the last 3 dimensions
  // post interchange.
868 LLVM_DEBUG(DBGS() << "paddedSizesNextMultipleOf: "
869 << llvm::interleaved(paddedSizesNextMultipleOf) << "\n"
870 << "loopRanges: "
871 << llvm::interleaved(llvm::map_range(
872 loopRanges, [](Range r) { return r.size; }))
873 << "\n");
874 SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
875 rewriter.getIndexAttr(0));
876 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
877 if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
879 continue;
880 }
881 AffineExpr d0, s0;
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
886 {loopRanges[adjustedPackedSizes.size()].size,
887 rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
888 }
889 LLVM_DEBUG(DBGS() << "adjustedPackedSizes: "
890 << llvm::interleaved(adjustedPackedSizes) << "\n");
891
892 // TODO: If we wanted to give the genericOp a name after packing, after
893 // calling `pack` would be a good time. One would still need to check that
894 // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
895 // also allow degenerate matmul cases (i.e. matvec, dot).
896 return pack(rewriter, genericOp, adjustedPackedSizes);
897}
898
899//===----------------------------------------------------------------------===//
900// Transformations exposed as rewrite patterns.
901//===----------------------------------------------------------------------===//
902
903LinalgTilingOptions &
904mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
905 assert(!tileSizeComputationFunction && "tile sizes already set");
906 SmallVector<int64_t, 4> tileSizes(ts);
907 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
908 OpBuilder::InsertionGuard guard(b);
909 b.setInsertionPointToStart(
910 &op->getParentOfType<func::FuncOp>().getBody().front());
911 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
912 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
913 return v;
914 }));
915 };
916 return *this;
917}
918
919LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
920 memref::CopyOp copyOp, PatternRewriter &rewriter) const {
921 return vectorizeCopy(rewriter, copyOp);
922}
923
/// Fill `dest` with the FillOp constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
926Value DecomposePadOpPattern::createFillOrGenerateOp(
927 RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
928 const SmallVector<Value> &dynSizes) const {
929 auto padValue = padOp.getConstantPaddingValue();
930 if (padValue) {
931 // Move the padding value defined inside the PadOp block to outside.
932 if (padValue.getParentBlock() == &padOp.getRegion().front())
933 rewriter.moveOpBefore(padValue.getDefiningOp(), padOp);
934 return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
935 }
936
937 // Fill could not be optimized: Lower to tensor::GenerateOp with region.
938 auto generateOp = rewriter.create<tensor::GenerateOp>(
939 padOp.getLoc(), padOp.getResultType(), dynSizes);
940 // Copy region to new op.
941 IRMapping bvm;
942 padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
943 return generateOp;
944}
945
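// DecomposePadOpPattern rewrites tensor.pad as tensor.empty + linalg.fill (or
// tensor.generate) + tensor.insert_slice. Schematic example, illustrative
// only, with types elided:
//   %0 = tensor.pad %src low[%l0, 1] high[2, %h1] { ... yield %cst ... }
// becomes roughly:
//   %empty = tensor.empty(...)
//   %fill  = linalg.fill ins(%cst) outs(%empty)
//   %0     = tensor.insert_slice %src into %fill[%l0, 1] [%d0, %d1] [1, 1]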
946LogicalResult
947DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
948 PatternRewriter &rewriter) const {
949 // Given an OpFoldResult, return an index-typed value.
950 auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
952 return val;
953 return rewriter
954 .create<arith::ConstantIndexOp>(
955 padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
956 .getResult();
957 };
958
959 auto resultType = padOp.getResultType();
960 // Compute size of EmptyOp. Any combination of static/dynamic is supported.
961 SmallVector<Value> dynSizes;
962 SmallVector<int64_t> staticSizes;
963 for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
964 if (resultType.isDynamicDim(dim)) {
      auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
                                                      padOp.getSource(), dim));
967 // Add low and high padding value.
968 auto plusLow = rewriter.createOrFold<arith::AddIOp>(
969 padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
970 auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
971 padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
973 }
    staticSizes.push_back(resultType.getDimSize(dim));
975 }
976
977 // Init tensor and fill it with padding.
978 Value emptyTensor = rewriter.create<tensor::EmptyOp>(
979 padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
980 Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);
981
  // Generate an InsertSliceOp for copying the PadOp source.
983 auto sourceType = padOp.getSourceType();
984 // Compute size of source of tensor::PadOp.
985 SmallVector<OpFoldResult> srcSizes =
      tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
987 // Strides of InsertSliceOp are all 1.
988 SmallVector<OpFoldResult> strides(sourceType.getRank(),
989 rewriter.getIndexAttr(1));
990 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
991 padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
992 strides);
993
994 return success();
995}
996
997LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
998 tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
999 if (!sliceOp.hasUnitStride())
1000 return failure();
1001
1002 auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
1003 if (!padOp)
1004 return failure();
1005
1006 bool zeroSliceGuard = true;
1007 if (controlFn) {
1008 if (std::optional<bool> control = controlFn(sliceOp))
1009 zeroSliceGuard = *control;
1010 else
1011 return failure();
1012 }
1013
1014 FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))
1018 return failure();
1019
1020 RankedTensorType sourceType = sliceOp.getSourceType();
1021 RankedTensorType resultType = sliceOp.getResultType();
1022
1023 // If the extract_slice is not rank-reduced, all shapes are static and the
1024 // data source is actually used. Rewrite into pad(extract_slice(x)).
1025 if (sourceType.getRank() == resultType.getRank()) {
1026 rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
1027 return success();
1028 }
1029
1030 // Handle rank-reduced slice by creating another extract_slice op.
1031 Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
1033
1034 rewriter.replaceOp(sliceOp, rankReduced);
1035 return success();
1036}
1037
1038/// If padding value is set, returns a tensor.pad Op for the source tensor,
1039/// with the output shape matching the output of `packOp`. Otherwise, returns
1040/// the source directly.
1041///
1042/// This method assumes that all outer dims for this pack Op are 1.
1043static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
1044 linalg::PackOp packOp) {
1045 Value input = packOp.getSource();
1046 if (!packOp.getPaddingValue()) {
1047 return input;
1048 }
1049
1050 assert(llvm::all_of(packOp.getAllOuterDims(),
1051 [](int64_t val) { return val == 1; }) &&
1052 "some outer dims are != 1");
1053
1054 Location loc = packOp.getLoc();
1055 ShapedType inputType = packOp.getSourceType();
1056 int64_t inputRank = inputType.getRank();
1057
1058 DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
1059 packOp.getDimAndTileMapping();
1060
  // The sizes of dynamic tiles.
1062 SmallVector<Value> dynamicTileSizes;
1063
1064 // Collect dims for the padded shape.
1065 SmallVector<int64_t> paddedShape;
1066 for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
1067 // 1. Non-tiled outer dims.
1068 // These dims should be 1 and we simply preserve them.
    if (!tileAndPosMapping.count(dimIdx)) {
1070 int64_t inputDimSize = inputType.getDimSize(dimIdx);
1071 assert(inputDimSize == 1 &&
1072 "with all outer dims == 1, this non-tiled input dim should be 1!");
      paddedShape.push_back(inputDimSize);
1074 continue;
1075 }
1076
1077 // 2. Tiled outer dims
1078 // As all outer dims == 1, it is safe to use the tile size for the padded
1079 // shape.
    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
1081
1082 // 2.1 Static tile sizes
    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
1084 if (cstTileSize.has_value()) {
      paddedShape.push_back(cstTileSize.value());
1086 continue;
1087 }
1088
1089 // 2.2 Dynamic tile sizes
1090 paddedShape.push_back(ShapedType::kDynamic);
1091
1092 // Get the value that holds the dynamic size.
    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
1094 }
1095 auto resultType =
1096 RankedTensorType::get(paddedShape, inputType.getElementType());
  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                 /*nofold=*/false, loc, builder,
                                 dynamicTileSizes);
1100}
1101
1102// Normalizes a permutation on a higher rank space to its actual size, e.g.
1103// perm = [1, 4, 2]
1104// becomes
1105// norm = [0, 2, 1]
1106static SmallVector<int64_t>
1107getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
1108 constexpr int64_t kNonTiledMarker = -1;
1109 SmallVector<int64_t> vec(rank, kNonTiledMarker);
  for (auto [index, value] : llvm::enumerate(perm))
    vec[value] = index;
  SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
      vec, [&](int64_t v) { return v != kNonTiledMarker; });
  // This inverts the permutation in addition to normalizing, so invert back.
  return invertPermutationVector(normalizedPerm);
1116}
1117
1118// Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
1119// assuming rank reduction of unit outer dims.
1120static SmallVector<int64_t>
1121getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
1122 ArrayRef<int64_t> innerDimsPos,
1123 ArrayRef<int64_t> outerDimsPerm) {
1124 SmallVector<int64_t> rankReducedOuterDimsPerm;
1125 SmallVector<int64_t> outerDims;
1126 SmallVector<int64_t> innerDims;
1127 int64_t dim = 0;
1128 int64_t unpackedRank = shape.size();
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
1132 continue;
1133 }
1134 if (shape[i] == 1)
1135 continue;
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1139 }
1140
1141 // Get the position of the inner dims after permutation.
1142 SmallVector<int64_t> innerPerm =
      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
  applyPermutationToVector<int64_t>(innerDims, innerPerm);
1145
1146 // Ditto for the outer dims.
1147 SmallVector<int64_t> perm = outerDims;
1148
1149 rankReducedOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
1151 if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1153
  // The tile always ends up as the innermost dims after packing.
  perm.append(innerDims);
1156
1157 return perm;
1158}
1159
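// DecomposeOuterUnitDimsPackOpPattern decomposes a linalg.pack whose outer
// dims are all 1s into a (possibly padded) linalg.transpose +
// tensor.insert_slice. Schematic example, illustrative only:
//   %0 = linalg.pack %src inner_dims_pos = [1, 0] inner_tiles = [16, 8]
//          into %dest : tensor<8x16xf32> -> tensor<1x1x16x8xf32>
// becomes roughly:
//   %init = tensor.empty() : tensor<16x8xf32>
//   %t = linalg.transpose ins(%src) outs(%init) permutation = [1, 0]
//   %0 = tensor.insert_slice %t into %dest[0, 0, 0, 0] [1, 1, 16, 8]
//          [1, 1, 1, 1]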
1160LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
1161 linalg::PackOp packOp, PatternRewriter &rewriter) const {
1162 // TODO: support the case that outer dimensions are not all 1s. A
1163 // tensor.expand_shape will be generated in this case.
1164 if (llvm::any_of(packOp.getAllOuterDims(),
1165 [](int64_t dim) { return dim != 1; })) {
1166 return rewriter.notifyMatchFailure(
1167 packOp, "not all outer dimensions of the result are 1s");
1168 }
1169
1170 Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1171 Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1172 Location loc = packOp.getLoc();
1173
1174 Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
1175 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1176 packOp.getDimAndTileMapping();
1177 int64_t srcRank = packOp.getSourceRank();
1178 int64_t destRank = packOp.getDestRank();
1179 int64_t numTiles = destRank - srcRank;
1180
1181 if (!llvm::all_of(packOp.getInnerDimsPos(),
1182 [&srcRank, &numTiles](int64_t dimPos) {
1183 return dimPos >= (srcRank - numTiles - 1);
1184 }))
1185 return rewriter.notifyMatchFailure(
1186 packOp, "Attempting to tile non-trailing source dims!");
1187
1188 // 1. Extract the inner tile sizes.
1189 // Where possible, values are replaced with constant attributes (to match the
1190 // behaviour of `getPackOpSourceOrPaddedSource`).
1191 SmallVector<OpFoldResult> tileSizes;
1192 for (auto i : llvm::seq<unsigned>(0, srcRank)) {
1193 if (dimAndTileMapping.count(i)) {
      // Rather than taking the tile size as is, extract the actual constant
1195 // value Attribute where possible, e.g.:
1196 // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
1197 auto [_, tileSize] =
1198 getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
1199 tileSizes.push_back(tileSize);
1200 }
1201 }
1202
1203 // 2. Transpose the input to match the inner tile order:
1204 // %init = tensor.empty()
1205 // %transposed_tile = linalg.transpose ins(%source_or_padded_source),
1206 // outs(%init)
1207 // Two assumptions are made:
1208 // 1. All outer dims are 1 - the corresponding transposition doesn't matter.
1209 // 2. Inner dims position correspond to the trailing `numTiles` dims.
1210 SmallVector<int64_t> tilesPermNormalized =
1211 getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
1212 SmallVector<int64_t> srcPermForTranspose;
1213 for (int64_t i = 0; i < (srcRank - numTiles); i++)
    srcPermForTranspose.push_back(i);
1215
  srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
1217
1218 LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
1219 << "perm: " << llvm::interleaved(srcPermForTranspose)
1220 << "\n");
1221
1222 // 2.1 Create tensor.empty (init value for TransposeOp)
1223 SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
1224 oneIdxAttr);
  transShapeForEmptyOp.append(tileSizes);
1226
  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
                                         srcPermForTranspose);
1229 Value empty = rewriter.create<tensor::EmptyOp>(
1230 loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());
1231
1232 // 2.2 Create linalg.transpose
1233 auto transposedOp = rewriter.create<linalg::TransposeOp>(loc, input, empty,
1234 srcPermForTranspose);
1235
1236 // 3. Insert the inner tile to the destination:
1237 // %inserted_tile = tensor.insert_slice(%transposed_tile)
1238 SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1239 SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1240 // Outer dims are all 1s!
1241 SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
1242 oneIdxAttr);
1243 SmallVector<int64_t> writeShape;
1244
1245 for (auto tileSize : packOp.getMixedTiles()) {
1246 auto [tileSizeStatic, tileSizeOfr] =
1247 getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
1248 writeSizes.push_back(tileSizeOfr);
1249 writeShape.push_back(tileSizeStatic);
1250 }
1251
  // 4. Replace the linalg.pack with the tensor.insert_slice created above.
1253 auto insert = rewriter.create<tensor::InsertSliceOp>(
1254 loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
1255 writeSizes, writeStrides);
1256 rewriter.replaceOp(packOp, insert.getResult());
1257
1258 return success();
1259}
1260
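// DecomposeOuterUnitDimsUnPackOpPattern decomposes a linalg.unpack whose tiled
// outer dims are all 1s into tensor.extract_slice + linalg.transpose +
// tensor.extract_slice + tensor.insert_slice. Schematic example, illustrative
// only:
//   %0 = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//          into %dest : tensor<1x1x8x16xf32> -> tensor<5x12xf32>
// becomes roughly:
//   %tile = tensor.extract_slice %src[0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1]
//   %init = tensor.empty() : tensor<8x16xf32>
//   %t    = linalg.transpose ins(%tile) outs(%init) permutation = [0, 1]
//   %part = tensor.extract_slice %t[0, 0] [5, 12] [1, 1]
//   %0    = tensor.insert_slice %part into %dest[0, 0] [5, 12] [1, 1]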
1261LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
1262 linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
1263 int64_t srcRank = unpackOp.getSourceRank();
1264 int64_t destRank = unpackOp.getDestRank();
1265 ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
1266 ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1267 if (llvm::any_of(unpackOp.getTiledOuterDims(),
1268 [](int64_t dim) { return dim != 1; })) {
1269 return rewriter.notifyMatchFailure(
1270 unpackOp,
1271 "require the tiled outer dimensions of the result are all 1s");
1272 }
1273
1274 // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
1275 // %extracted_tile = tensor.extract_slice(%unpack_op_input)
1276 Location loc = unpackOp.getLoc();
1277 Value source = unpackOp.getSource();
1278 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1279 unpackOp.getDimAndTileMapping();
1280 Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1281 Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1282
1283 // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
1284 // dims:
1285 // [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
1286 SmallVector<int64_t> readShapeForExtractSlice;
1287 // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
1288 // outer-tiled-dims being all 1), this will be
1289 // [ outer-untiled-dims, tile-sizes ]
1290 SmallVector<OpFoldResult> extractSliceSizes;
1291 // The offset and strides attributes for ExtractSliceOp.
1292 SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
1293 SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
1294
1295 // Shape for EmptyOp that's used as the init value for TransposeOp below.
1296 // This should be:
1297 // [ outer-untiled-dims, tile-sizes ]
1298 // However, skip unit dims - TransposeOp (below) applies rank-reduced
1299 // permutation.
1300 SmallVector<OpFoldResult> shapeForEmptyOp;
1301
1302 for (auto i : llvm::seq<unsigned>(0, destRank)) {
1303 // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
1304 //
    // All outer tiled dims are 1, so the corresponding slice size to read is
    // also 1. As this will be a rank-reducing "extract slice" (i.e. the unit
    // dims will be "collapsed"), there's no need to
1308 // update:
1309 // * the output shape for ExtractSliceOp, nor
1310 // * the shape for EmptyOp.
1311 if (dimAndTileMapping.count(i)) {
1312 extractSliceSizes.push_back(oneIdxAttr);
1313 continue;
1314 }
1315
1316 // Compute sizes attribute for ExtractSliceOp + EmptyOp -
1317 // outer-untiled-dims
1318 if (ShapedType::isDynamic(srcShape[i])) {
1319 OpFoldResult dynamicDim =
1320 rewriter.create<tensor::DimOp>(loc, source, i).getResult();
1321 extractSliceSizes.push_back(dynamicDim);
1322 shapeForEmptyOp.push_back(dynamicDim);
1323 } else {
1324 extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
1325 if (srcShape[i] != 1)
1326 shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
1327 }
1328 // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
1329 // into account rank-reducing)
1330 if (srcShape[i] != 1) {
1331 readShapeForExtractSlice.push_back(srcShape[i]);
1332 }
1333 }
  // Append the tile sizes to the "sizes attribute" for ExtractSliceOp and to
  // the shape for EmptyOp.
  auto mixedTiles = unpackOp.getMixedTiles();
  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());

  // Explicitly create the type for extract_slice op because the inner tile
  // size could be 1. We want to represent the whole inner tile in this case.
  auto tileShape = srcShape.drop_front(destRank);
  // Append the inner tile shape to the permuted and rank-reduced outer shape.
  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
  Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, readType, unpackOp.getSource(), extractSliceOffsets,
      extractSliceSizes, extractSliceStrides);

  // 2. Transpose the tile to match the corresponding outer tile order.
  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
  // Unpack is a transition out of packed space so we invert the permutation.
  perm = invertPermutationVector(perm);
  applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);

  Value empty =
      rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
  auto transposedOp =
      rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);

  // 3. Handle incomplete tiles if needed. This truncates trailing data from
  // the transposed tile.
  int numLoops = shapeForEmptyOp.size();
  SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
  SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
  SmallVector<OpFoldResult> tileSizes;
  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      tileSizes.push_back(
          tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
  }

  auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);

  // 4. Insert the result into the destination tensor.
  SmallVector<OpFoldResult> writeSizes;
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  for (int i = 0, idx = 0; i < destRank; ++i) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      writeSizes.push_back(tileSizes[idx++]);
    else
      writeSizes.push_back(oneIdxAttr);
  }
  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
      writeStrides);
  rewriter.replaceOp(unpackOp, insert.getResult());

  return success();
}

// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

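/// Illustrative sketch (operand shapes are assumptions for documentation
/// purposes only): a 2-D convolution whose H window and H output are both of
/// size 1, e.g.
///
///   linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
///                             strides = dense<1> : tensor<2xi64>}
///     ins(%in, %ker : tensor<1x1x16x4xf32>, tensor<1x3x4x8xf32>)
///     outs(%out : tensor<1x1x14x8xf32>) -> tensor<1x1x14x8xf32>
///
/// is rewritten into the corresponding 1-D op (here linalg.conv_1d_nwc_wcf)
/// on rank-reduced operands: the unit H dimension is dropped with
/// tensor.extract_slice and the result is re-inserted with
/// tensor.insert_slice.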
template <typename Conv2DOp, typename Conv1DOp>
FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Get domain indices based on conv2D layout.
  auto [khIndex, kwIndex, ohIndex, owIndex] =
      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
          convOp)
          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::Conv2DNchwFchwOp op) {
            return std::make_tuple(2, 3, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcSumOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwSumOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcMaxOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwMaxOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Default([&](Operation *op) {
            llvm_unreachable("unexpected conv2d/pool2d operation.");
            return std::make_tuple(0, 0, 0, 0);
          });

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides =
      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<Conv1DOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}

template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
                                                              Conv1DNwcWcfOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
                                                              Conv1DNcwFcwOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
                                                              PoolingNwcSumOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
                                                              PoolingNcwSumOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
                                                              PoolingNwcMaxOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
                                                              PoolingNwcMinOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
                                                              PoolingNcwMaxOp>;

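/// The same rewrite for depthwise convolutions. As a sketch (shapes assumed
/// for documentation only): with a kernel of type tensor<1x3x4xf32> (kh = 1)
/// and an output of type tensor<1x1x14x4xf32> (oh = 1), a
/// linalg.depthwise_conv_2d_nhwc_hwc collapses to a
/// linalg.depthwise_conv_1d_nwc_wc that convolves along W only.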
FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}

FailureOr<Conv1DOp>
DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
                                            PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[0], owSize = outputShape[1];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
                                            ValueRange{newInput, newKernel},
                                            ValueRange{newOutput});

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}

void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
                                                     Conv1DNwcWcfOp>,
               DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
                                                     Conv1DNcwFcwOp>,
               DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
      patterns.getContext(), benefit);
  patterns.add<
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
                                            PoolingNwcMaxUnsignedOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
                                            PoolingNwcMinUnsignedOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
      patterns.getContext(), benefit);
}
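
// A minimal usage sketch (illustrative only; `funcOp` and `ctx` are assumed
// to be provided by the surrounding pass):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDecomposeConvolutionPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();
//
// The same greedy-driver setup applies to the pack/unpack and pad populate
// functions below.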

void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
  patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
  patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(patterns.getContext());
}

void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
  patterns.add<DecomposePadOpPattern>(patterns.getContext());
}
