//===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"

#include "llvm/ADT/TypeSwitch.h"

#include <cstdlib>

using namespace mlir;
using namespace mlir::tensor;

namespace {

22/// Rewrite tensor.generate with arith.constant if the yielded value is a
23/// constant and the tensor type is static.
24struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
25 using OpRewritePattern<GenerateOp>::OpRewritePattern;
26
27 LogicalResult matchAndRewrite(GenerateOp generateOp,
28 PatternRewriter &rewriter) const override {
29 auto tensorType =
30 llvm::cast<RankedTensorType>(generateOp.getResult().getType());
31 if (!tensorType.hasStaticShape())
32 return failure();
33 auto terminatorOp =
34 cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
35 Attribute attr;
36 if (!matchPattern(terminatorOp.getValue(), m_Constant(bind_value: &attr)))
37 return failure();
38 Operation *constantOp =
39 rewriter.getContext()
40 ->getLoadedDialect<TensorDialect>()
41 ->materializeConstant(rewriter,
42 DenseElementsAttr::get(tensorType, attr),
43 tensorType, generateOp->getLoc());
44 if (!constantOp)
45 return failure();
46 rewriter.replaceOp(generateOp, constantOp->getResults());
47 return success();
48 }
49};
50
51/// Transform a linear index from one indexing space to another given:
52///
53/// - the shape of the source indexing space,
54/// - the strides of the target indexing space,
55/// - a linear index into the source indexing space.
56///
57/// This function is logically a sequence of linearize/delinearize over
58/// different bases but avoids allocating intermediate SmallVectors.
59int64_t transformIndexSpace(ArrayRef<int64_t> inputShape,
60 ArrayRef<int64_t> outputStrides,
61 int64_t srcLinearIndex) {
62 assert(inputShape.size() == outputStrides.size());
63
64 int64_t dstLinearIndex = 0;
65
66 for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
67 // Compute the index into the current dimension of the source tensor.
68 // `quotient` is the remaining linear index after accounting for the
69 // current dimension.
70 //
71 // `remainder` is the index into the source tensor for the current
72 // dimension.
73 auto [quotient, remainder] = std::div(i: srcLinearIndex, j: inputShape[dim]);
74
75 srcLinearIndex = quotient;
76
77 // Add the contribution of the current dimension to the output using the
78 // permutation map.
79 dstLinearIndex += outputStrides[dim] * remainder;
80 }
81
82 return dstLinearIndex;
83}
84
85template <typename ElemType, typename AttrType>
86Value constantFoldPadOp(PatternRewriter &rewriter, Location loc,
87 DenseElementsAttr input, AttrType padValue,
88 ArrayRef<int64_t> padLow, ArrayRef<int64_t> padHigh) {
89 auto inputValues = input.tryGetValues<ElemType>();
90 if (failed(inputValues))
91 return nullptr;
92
93 auto oldShape = input.getType().getShape();
94
95 // Compute the output shape of the new value.
96 auto newShape =
97 llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh),
98 [](std::tuple<int64_t, int64_t, int64_t> pack) {
99 auto [old, low, high] = pack;
100 return old + low + high;
101 });
102
103 int64_t outputSize = computeProduct(newShape);
104
105 // Fully initialize the vector with the padding value.
106 // The non-padded area will then be copied.
107 SmallVector<ElemType> values(outputSize, padValue.getValue());
108
109 // Strides for input and output are used to transform between the indexing
110 // space of the input and output tensors.
111 SmallVector<int64_t> outputStrides = computeStrides(newShape);
112
113 // The contribution of the low padding to the offset in the output tensor.
114 // This is the starting position of the source tensor within the padding
115 // tensor.
116 int64_t startingOffset = linearize(offsets: padLow, basis: outputStrides);
117
118 // Copy values from the input tensor to the corresponding sub-region
119 // of the output tensor.
120 for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) {
121 auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex);
122 values[outputIndex + startingOffset] = inputValue;
123 }
124
125 // Create an attribute for the folded value.
126 auto newType = input.getType().clone(newShape);
127 auto newAttr = DenseElementsAttr::get(newType, values);
128
129 Operation *constantOp =
130 rewriter.getContext()
131 ->getLoadedDialect<TensorDialect>()
132 ->materializeConstant(rewriter, newAttr, newType, loc);
133
134 return constantOp ? constantOp->getResult(idx: 0) : nullptr;
135}
136
137struct PadOpToConstant final : public OpRewritePattern<PadOp> {
138
139 PadOpToConstant(MLIRContext *context, const ControlFoldFn &controlFn,
140 PatternBenefit benefit = 1)
141 : OpRewritePattern<PadOp>(context, benefit), controlFn{controlFn} {}
142
143 LogicalResult matchAndRewrite(PadOp padTensorOp,
144 PatternRewriter &rewriter) const override {
145 if (padTensorOp.getNofold())
146 return rewriter.notifyMatchFailure(
147 padTensorOp, "refusing to fold nofold pad operation");
148
149 TypedValue<RankedTensorType> input = padTensorOp.getSource();
150 RankedTensorType resultType = padTensorOp.getResult().getType();
151
152 DenseElementsAttr inputAttr = nullptr;
153 if (!matchPattern(value: input, pattern: m_Constant(bind_value: &inputAttr)))
154 return failure();
155
156 Value paddingValue = padTensorOp.getConstantPaddingValue();
157
158 // Extract the constant value used for padding or bail out.
159 Attribute paddingAttr = nullptr;
160 if (!paddingValue || !matchPattern(value: paddingValue, pattern: m_Constant(bind_value: &paddingAttr)))
161 return rewriter.notifyMatchFailure(padTensorOp,
162 "unable to get constant value");
163
164 // Try to extract the constant values of the low and high padding.
165 auto lowPad = getConstantIntValues(padTensorOp.getMixedLowPad());
166 auto highPad = getConstantIntValues(padTensorOp.getMixedHighPad());
167
168 // If the padding cannot be extracted, bail out.
169 if (!lowPad || !highPad)
170 return rewriter.notifyMatchFailure(padTensorOp,
171 "unable to extract constant padding");
172
173 // We have a potential candidate, consult the control function to
174 // determine if the op should fold.
175 if (!controlFn(&padTensorOp.getSourceMutable()))
176 return rewriter.notifyMatchFailure(padTensorOp,
177 "not folding due to cost function");
178
179 Location loc = padTensorOp.getLoc();
180
181 // Try constant folding the supported cases of integer and float values.
182 Value newOp =
183 llvm::TypeSwitch<Attribute, Value>(paddingAttr)
184 .Case(caseFn: [&](FloatAttr floatAttr) {
185 return constantFoldPadOp<llvm::APFloat>(
186 rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad);
187 })
188 .Case(caseFn: [&](IntegerAttr integerAttr) {
189 return constantFoldPadOp<llvm::APInt>(
190 rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
191 })
192 .Default(defaultResult: Value());
193
194 if (!newOp)
195 return rewriter.notifyMatchFailure(padTensorOp,
196 "tensor type not supported");
197
198 if (newOp.getType() != resultType)
199 newOp = rewriter.create<tensor::CastOp>(loc, resultType, newOp);
200
201 rewriter.replaceOp(padTensorOp, newOp);
202 return success();
203 }
204
205private:
206 ControlFoldFn controlFn;
207};
208
} // namespace

211void mlir::tensor::populateRewriteAsConstantPatterns(
212 RewritePatternSet &patterns, const ControlFoldFn &controlFn) {
213 patterns.add<GenerateToConstant>(arg: patterns.getContext());
214
215 patterns.add<PadOpToConstant>(arg: patterns.getContext(), args: controlFn);
216}
217

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp