//===- FusePadOpWithLinalgProducer.cpp ---- Fuse pad with linalg producer -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns that fuse a linalg.generic -> tensor.pad op
// chain into a tensor.extract_slice -> linalg.generic -> tensor.insert_slice
// op chain.
//
//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16
17#include "mlir/Dialect/Linalg/IR/Linalg.h"
18
19using namespace mlir;
20
21namespace {
22
23/// A sequence of operations
24///
25/// ```mlir
26/// %0 = linalg. ...
27/// %1 = tensor.pad %0 ...
28/// ```
29///
30/// can be replaced with
31///
32/// ```mlir
33/// %0 = linalg.fill
34/// %1 = tensor.extract_slice %0 ...
35/// %2 = linalg. .... outs(..., %1, ....) ....
36/// %3 = tensor.insert_slice %2 into %1 ...
37/// ```
38///
39/// if the `linalg.generic` has all parallel iterator types.
40struct FusePadOp : OpRewritePattern<tensor::PadOp> {
41 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
42
43 LogicalResult matchAndRewrite(tensor::PadOp padOp,
44 PatternRewriter &rewriter) const override {
45 // Only works on padding op that sets the padded value to a constant.
46 Value padValue = padOp.getConstantPaddingValue();
47 if (!padValue)
48 return rewriter.notifyMatchFailure(arg&: padOp, msg: "non constant padding");
49
50 // This pattern could work for any Linalg op. For now restrict it to generic
51 // ops.
52 Value source = padOp.getSource();
53 auto linalgOp = source.getDefiningOp<linalg::GenericOp>();
54 if (!linalgOp) {
55 return rewriter.notifyMatchFailure(
56 arg&: padOp, msg: "expected source to be linalg.generic op");
57 }
58 // All iterator types need to be parallel.
59 if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops()) {
60 return rewriter.notifyMatchFailure(
61 arg&: padOp, msg: "only supported for ops with all parallel iterator types");
62 }
63 ReifiedRankedShapedTypeDims resultShape;
64 if (failed(Result: reifyResultShapes(b&: rewriter, op: padOp, reifiedReturnShapes&: resultShape)) ||
65 resultShape.size() != 1) {
66 return rewriter.notifyMatchFailure(
67 arg&: padOp, msg: "failed to get shape of pad op result");
68 }
69
70 Location loc = padOp.getLoc();
71
72 // Create the tensor of same size as output of the pad op.
73 RankedTensorType padResultType = padOp.getResultType();
74 auto resultSizes = resultShape[0];
75 auto emptyTensor = rewriter.create<tensor::EmptyOp>(
76 location: loc, args&: resultSizes, args: padResultType.getElementType());
77
78 // Fill the tensor with the pad value.
79 // TODO: There is an option to fill only the boundaries. For now just
80 // filling the whole tensor.
81 auto fillTensor =
82 rewriter.create<linalg::FillOp>(location: loc, args&: padValue, args: emptyTensor.getResult());
83
84 // Construct a slice of the fill result that is to be replaced with the
85 // result of the generic op. The low pad values are the offsets, the size of
86 // the source is the size of the slice.
87 // TODO: This insert/extract could be potentially made a utility method.
88 unsigned resultNumber = cast<OpResult>(Val&: source).getResultNumber();
89 SmallVector<OpFoldResult> offsets = padOp.getMixedLowPad();
90 SmallVector<OpFoldResult> sizes;
91 sizes.reserve(N: offsets.size());
92 for (const auto &shape :
93 llvm::enumerate(First: cast<RankedTensorType>(Val: source.getType()).getShape())) {
94 if (ShapedType::isDynamic(dValue: shape.value())) {
95 sizes.push_back(
96 Elt: rewriter.create<tensor::DimOp>(location: loc, args&: source, args: shape.index())
97 .getResult());
98 } else {
99 sizes.push_back(Elt: rewriter.getIndexAttr(value: shape.value()));
100 }
101 }
102 SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(value: 1));
103 auto slice = rewriter.create<tensor::ExtractSliceOp>(
104 location: loc, args: fillTensor.getResult(i: 0), args&: offsets, args&: sizes, args&: strides);
105
106 // Clone the generic op.
107 auto clonedOp =
108 cast<linalg::GenericOp>(Val: rewriter.clone(op&: *linalgOp.getOperation()));
109 clonedOp.setDpsInitOperand(i: resultNumber, value: slice.getResult());
110
111 // Insert it back into the result of the fill.
112 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
113 op: padOp, args: clonedOp.getResult(i: resultNumber), args: fillTensor.getResult(i: 0),
114 args&: offsets, args&: sizes, args&: strides);
115 return success();
116 }
117};
118} // namespace
119
120void mlir::linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(
121 RewritePatternSet &patterns) {
122 patterns.add<FusePadOp>(arg: patterns.getContext());
123}
124

source code of mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp