1 | //===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
10 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
11 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
12 | #include "mlir/IR/Matchers.h" |
13 | #include "mlir/IR/PatternMatch.h" |
14 | |
15 | #include "llvm/ADT/TypeSwitch.h" |
16 | |
17 | using namespace mlir; |
18 | using namespace mlir::tensor; |
19 | |
20 | namespace { |
21 | |
22 | /// Rewrite tensor.generate with arith.constant if the yielded value is a |
23 | /// constant and the tensor type is static. |
24 | struct GenerateToConstant : public OpRewritePattern<GenerateOp> { |
25 | using OpRewritePattern<GenerateOp>::OpRewritePattern; |
26 | |
27 | LogicalResult matchAndRewrite(GenerateOp generateOp, |
28 | PatternRewriter &rewriter) const override { |
29 | auto tensorType = |
30 | llvm::cast<RankedTensorType>(generateOp.getResult().getType()); |
31 | if (!tensorType.hasStaticShape()) |
32 | return failure(); |
33 | auto terminatorOp = |
34 | cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator()); |
35 | Attribute attr; |
36 | if (!matchPattern(terminatorOp.getValue(), m_Constant(bind_value: &attr))) |
37 | return failure(); |
38 | Operation *constantOp = |
39 | rewriter.getContext() |
40 | ->getLoadedDialect<TensorDialect>() |
41 | ->materializeConstant(rewriter, |
42 | DenseElementsAttr::get(tensorType, attr), |
43 | tensorType, generateOp->getLoc()); |
44 | if (!constantOp) |
45 | return failure(); |
46 | rewriter.replaceOp(generateOp, constantOp->getResults()); |
47 | return success(); |
48 | } |
49 | }; |
50 | |
51 | /// Transform a linear index from one indexing space to another given: |
52 | /// |
53 | /// - the shape of the source indexing space, |
54 | /// - the strides of the target indexing space, |
55 | /// - a linear index into the source indexing space. |
56 | /// |
57 | /// This function is logically a sequence of linearize/delinearize over |
58 | /// different bases but avoids allocating intermediate SmallVectors. |
59 | int64_t transformIndexSpace(ArrayRef<int64_t> inputShape, |
60 | ArrayRef<int64_t> outputStrides, |
61 | int64_t srcLinearIndex) { |
62 | assert(inputShape.size() == outputStrides.size()); |
63 | |
64 | int64_t dstLinearIndex = 0; |
65 | |
66 | for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) { |
67 | // Compute the index into the current dimension of the source tensor. |
68 | // `quotient` is the remaining linear index after accounting for the |
69 | // current dimension. |
70 | // |
71 | // `remainder` is the index into the source tensor for the current |
72 | // dimension. |
73 | auto [quotient, remainder] = std::div(i: srcLinearIndex, j: inputShape[dim]); |
74 | |
75 | srcLinearIndex = quotient; |
76 | |
77 | // Add the contribution of the current dimension to the output using the |
78 | // permutation map. |
79 | dstLinearIndex += outputStrides[dim] * remainder; |
80 | } |
81 | |
82 | return dstLinearIndex; |
83 | } |
84 | |
85 | template <typename ElemType, typename AttrType> |
86 | Value constantFoldPadOp(PatternRewriter &rewriter, Location loc, |
87 | DenseElementsAttr input, AttrType padValue, |
88 | ArrayRef<int64_t> padLow, ArrayRef<int64_t> padHigh) { |
89 | auto inputValues = input.tryGetValues<ElemType>(); |
90 | if (failed(inputValues)) |
91 | return nullptr; |
92 | |
93 | auto oldShape = input.getType().getShape(); |
94 | |
95 | // Compute the output shape of the new value. |
96 | auto newShape = |
97 | llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh), |
98 | [](std::tuple<int64_t, int64_t, int64_t> pack) { |
99 | auto [old, low, high] = pack; |
100 | return old + low + high; |
101 | }); |
102 | |
103 | int64_t outputSize = computeProduct(newShape); |
104 | |
105 | // Fully initialize the vector with the padding value. |
106 | // The non-padded area will then be copied. |
107 | SmallVector<ElemType> values(outputSize, padValue.getValue()); |
108 | |
109 | // Strides for input and output are used to transform between the indexing |
110 | // space of the input and output tensors. |
111 | SmallVector<int64_t> outputStrides = computeStrides(newShape); |
112 | |
113 | // The contribution of the low padding to the offset in the output tensor. |
114 | // This is the starting position of the source tensor within the padding |
115 | // tensor. |
116 | int64_t startingOffset = linearize(offsets: padLow, basis: outputStrides); |
117 | |
118 | // Copy values from the input tensor to the corresponding sub-region |
119 | // of the output tensor. |
120 | for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) { |
121 | auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex); |
122 | values[outputIndex + startingOffset] = inputValue; |
123 | } |
124 | |
125 | // Create an attribute for the folded value. |
126 | auto newType = input.getType().clone(newShape); |
127 | auto newAttr = DenseElementsAttr::get(newType, values); |
128 | |
129 | Operation *constantOp = |
130 | rewriter.getContext() |
131 | ->getLoadedDialect<TensorDialect>() |
132 | ->materializeConstant(rewriter, newAttr, newType, loc); |
133 | |
134 | return constantOp ? constantOp->getResult(idx: 0) : nullptr; |
135 | } |
136 | |
137 | struct PadOpToConstant final : public OpRewritePattern<PadOp> { |
138 | |
139 | PadOpToConstant(MLIRContext *context, const ControlFoldFn &controlFn, |
140 | PatternBenefit benefit = 1) |
141 | : OpRewritePattern<PadOp>(context, benefit), controlFn{controlFn} {} |
142 | |
143 | LogicalResult matchAndRewrite(PadOp padTensorOp, |
144 | PatternRewriter &rewriter) const override { |
145 | if (padTensorOp.getNofold()) |
146 | return rewriter.notifyMatchFailure( |
147 | padTensorOp, "refusing to fold nofold pad operation" ); |
148 | |
149 | TypedValue<RankedTensorType> input = padTensorOp.getSource(); |
150 | RankedTensorType resultType = padTensorOp.getResult().getType(); |
151 | |
152 | DenseElementsAttr inputAttr = nullptr; |
153 | if (!matchPattern(value: input, pattern: m_Constant(bind_value: &inputAttr))) |
154 | return failure(); |
155 | |
156 | Value paddingValue = padTensorOp.getConstantPaddingValue(); |
157 | |
158 | // Extract the constant value used for padding or bail out. |
159 | Attribute paddingAttr = nullptr; |
160 | if (!paddingValue || !matchPattern(value: paddingValue, pattern: m_Constant(bind_value: &paddingAttr))) |
161 | return rewriter.notifyMatchFailure(padTensorOp, |
162 | "unable to get constant value" ); |
163 | |
164 | // Try to extract the constant values of the low and high padding. |
165 | auto lowPad = getConstantIntValues(padTensorOp.getMixedLowPad()); |
166 | auto highPad = getConstantIntValues(padTensorOp.getMixedHighPad()); |
167 | |
168 | // If the padding cannot be extracted, bail out. |
169 | if (!lowPad || !highPad) |
170 | return rewriter.notifyMatchFailure(padTensorOp, |
171 | "unable to extract constant padding" ); |
172 | |
173 | // We have a potential candidate, consult the control function to |
174 | // determine if the op should fold. |
175 | if (!controlFn(&padTensorOp.getSourceMutable())) |
176 | return rewriter.notifyMatchFailure(padTensorOp, |
177 | "not folding due to cost function" ); |
178 | |
179 | Location loc = padTensorOp.getLoc(); |
180 | |
181 | // Try constant folding the supported cases of integer and float values. |
182 | Value newOp = |
183 | llvm::TypeSwitch<Attribute, Value>(paddingAttr) |
184 | .Case(caseFn: [&](FloatAttr floatAttr) { |
185 | return constantFoldPadOp<llvm::APFloat>( |
186 | rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad); |
187 | }) |
188 | .Case(caseFn: [&](IntegerAttr integerAttr) { |
189 | return constantFoldPadOp<llvm::APInt>( |
190 | rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad); |
191 | }) |
192 | .Default(defaultResult: Value()); |
193 | |
194 | if (!newOp) |
195 | return rewriter.notifyMatchFailure(padTensorOp, |
196 | "tensor type not supported" ); |
197 | |
198 | if (newOp.getType() != resultType) |
199 | newOp = rewriter.create<tensor::CastOp>(loc, resultType, newOp); |
200 | |
201 | rewriter.replaceOp(padTensorOp, newOp); |
202 | return success(); |
203 | } |
204 | |
205 | private: |
206 | ControlFoldFn controlFn; |
207 | }; |
208 | |
209 | } // namespace |
210 | |
211 | void mlir::tensor::populateRewriteAsConstantPatterns( |
212 | RewritePatternSet &patterns, const ControlFoldFn &controlFn) { |
213 | patterns.add<GenerateToConstant>(arg: patterns.getContext()); |
214 | |
215 | patterns.add<PadOpToConstant>(arg: patterns.getContext(), args: controlFn); |
216 | } |
217 | |