//===- TosaFolders.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA operations
//
//===----------------------------------------------------------------------===//

#include <functional>
#include <numeric>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Apply the given transformation \p toApply to every element of the tensor to
/// be transformed \p toTransform.
///
/// Elements of \p toTransform are extracted as \p SrcValType.
///
/// \returns A tensor with the same size as \p toTransform, containing
/// \p TargetValType values of type \p TargetType.
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
    const DenseElementsAttr &toTransform,
    const std::function<TargetValType(const SrcValType &)> &toApply,
    TargetType targetType) {
  SmallVector<TargetValType> transformedValues;
  // We already know the number of values we will insert; reserve space for
  // all of them to avoid dynamic resizing.
  transformedValues.reserve(toTransform.getNumElements());
  for (auto val : toTransform.getValues<SrcValType>()) {
    auto transformedVal = toApply(val);
    transformedValues.push_back(transformedVal);
  }

  // Make sure that the output tensor has the expected output type.
  auto inType = toTransform.getType();
  auto outTy = inType.cloneWith({}, targetType);

  return DenseElementsAttr::get(outTy, transformedValues);
}

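// Explicit template instantiation for the float-to-float case, which is the
// one exercised by the reciprocal folding below.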
template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
    const DenseElementsAttr &toTransform,
    const std::function<APFloat(const APFloat &)> &toApply,
    FloatType targetType);

/// Function that checks if the element type of \p toCheck is float.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
                               PatternRewriter &rewriter) {
  if (isa<FloatType>(toCheck.getType().getElementType())) {
    return success();
  }
  return rewriter.notifyMatchFailure(location,
                                     "Unexpected input tensor type: the "
                                     "TOSA spec only allows floats");
}

/// Function that checks if \p toCheck is a dense TOSA constant tensor.
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
                                                TosaOp location,
                                                PatternRewriter &rewriter) {
  // Check whether the tensor is constant and dense.
  // TODO: We currently ensure the tensor is dense by binding a
  // DenseElementsAttr that we do not actually need. It would be nicer to only
  // have a check here.
  DenseElementsAttr tmp;
  if (!matchPattern(toCheck, m_Constant(&tmp))) {
    return rewriter.notifyMatchFailure(location,
                                       "Non-const or non-dense input tensor");
  }

  // Make sure it actually is a TOSA constant (the match allows for other
  // constants as well).
  if (isa<ConstOp>(toCheck.getDefiningOp())) {
    return success();
  }

  return rewriter.notifyMatchFailure(location,
                                     "The operation can only be folded if "
                                     "it operates on a TOSA constant");
}

/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
                                                 TosaOp location,
                                                 PatternRewriter &rewriter) {
  auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
  if (failed(floatCheck)) {
    return floatCheck;
  }
  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
}

/// Heuristic to decide when to replace a unary operation on a constant with
/// the folded value.
/// Folding operations on constants can lead to increased memory usage
/// whenever the input cannot be replaced but a new constant is inserted.
/// Hence, this will currently only suggest folding when the memory impact is
/// negligible.
/// Takes the \p unaryOp and the constant input \p values.
/// \returns Whether folding should be applied.
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
  assert(unaryOp->getNumOperands() == 1);
  auto inputOp = unaryOp->getOperand(0);

  // If the input is a splat, we don't care about the number of users.
  if (isa<SplatElementsAttr>(values)) {
    return true;
  }

  // If this is the only use of the tensor, it should be replaced, as no
  // additional memory is required.
  return inputOp.hasOneUse();
}

template <typename RangeType>
DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
                                ShapedType outputType,
                                llvm::ArrayRef<int64_t> permValues) {
  using ElementType = std::decay_t<decltype(*std::begin(data))>;

  assert(inputType.getElementType() == outputType.getElementType());

  if (inputType.getNumElements() == 0)
    return DenseElementsAttr::get(outputType, llvm::ArrayRef<ElementType>{});

  auto inputShape = inputType.getShape();

  // The inverted permutation map and strides of the output are used to compute
  // the contribution of a given dimension to the destination linear index in
  // an order-independent way.
  auto outputStrides = computeStrides(outputType.getShape());
  auto invertedPermValues = invertPermutationVector(permValues);
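  // For example, with permValues = [1, 0] on a 2x3 input, the output is 3x2
  // with strides [2, 1]. The element at input position (i, j) gets
  // dstLinearIndex = outputStrides[invertedPermValues[0]] * i
  //                + outputStrides[invertedPermValues[1]] * j = i + 2 * j,
  // which is the linear index of output position (j, i).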

  auto initialValue = *std::begin(data);
  SmallVector<ElementType> outputValues(inputType.getNumElements(),
                                        initialValue);

  for (const auto &it : llvm::enumerate(data)) {
    auto srcLinearIndex = it.index();

    uint64_t dstLinearIndex = 0;
    for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
      // Compute the index into the current dimension of the source vector.
      auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
      srcLinearIndex /= inputShape[dim];

      // Add the contribution of the current dimension to the output using the
      // permutation map.
      dstLinearIndex +=
          outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
    }

    outputValues[dstLinearIndex] = it.value();
  }

  return DenseElementsAttr::get(outputType,
                                llvm::ArrayRef<ElementType>(outputValues));
}

// A type specialized transposition of an ElementsAttr.
// This implementation tries to operate on the underlying data in its raw
// representation when possible to avoid allocating a large number of Attribute
// objects.
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                            ShapedType outputType,
                            llvm::ArrayRef<int64_t> permValues) {
  if (auto data = attr.tryGetValues<bool>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int8_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int16_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int32_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int64_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<float>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<APFloat>())
    return transposeType(*data, inputType, outputType, permValues);

  return nullptr;
}

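/// Fold a tosa.transpose of a constant input tensor into a new tosa.const
/// holding the transposed values.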
struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto outputType = cast<ShapedType>(op.getType());
    // TOSA supports quantized types.
    if (!outputType.getElementType().isIntOrIndexOrFloat())
      return failure();

    ElementsAttr inputValues;
    if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
      return failure();
    // Make sure the input is a constant that has a single user.
    if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
      return failure();

    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return failure();
    auto permValues = llvm::map_to_vector(
        // TOSA allows both 32- and 64-bit integer tensors here.
        permAttr.getValues<APInt>(),
        [](const APInt &val) { return val.getSExtValue(); });

    auto inputType = cast<ShapedType>(op.getInput1().getType());

    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
    if (!resultAttr) {
      return rewriter.notifyMatchFailure(
          op, "unsupported attribute or element type");
    }

    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
    return success();
  }
};

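/// Fold a tosa.reciprocal applied to a constant float tensor into a
/// tosa.const holding the element-wise reciprocals.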
struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {

  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReciprocalOp recip,
                                PatternRewriter &rewriter) const override {
    auto inputTensor = recip.getInput1();

    // Check that we can apply folding.
    auto preCondCheck =
        notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
    if (failed(preCondCheck)) {
      return preCondCheck;
    }

    // Extract the tensor values.
    DenseElementsAttr inputValues;
    matchPattern(inputTensor, m_Constant(&inputValues));

    // Check whether this should be folded.
    if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
      return rewriter.notifyMatchFailure(
          recip, "Currently, reciprocals will only be folded if the input "
                 "tensor has a single user");
    }

    // Create a new tensor with the updated values.
    auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
        inputValues, &ReciprocalOp::calcOneElement,
        cast<FloatType>(inputValues.getElementType()));

    // Replace the use of the reciprocal with the transformed tensor.
    rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
    return success();
  }
};

/// Compute the position along each axis (in row-major order) of the element
/// that sits at the given linear \p index in a tensor of shape \p tensorShape.
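/// E.g., index 5 in shape [2, 3] corresponds to position [1, 2].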
llvm::SmallVector<int64_t>
getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
  int64_t remaining = index;
  llvm::SmallVector<int64_t> position(tensorShape.size(), 0);
  for (int64_t i = tensorShape.size() - 1; i >= 0; --i) {
    position[i] = remaining % tensorShape[i];
    remaining /= tensorShape[i];
  }
  return position;
}

/// Compute the linear (row-major) index of the element located at the given
/// axis \p position in a tensor of shape \p tensorShape.
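/// E.g., position [1, 2] in shape [2, 3] corresponds to linear index 5.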
int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
                             llvm::ArrayRef<int64_t> tensorShape) {
  int64_t index = 0;
  int64_t multiplierTmp = 1;
  for (int64_t i = position.size() - 1; i >= 0; --i) {
    index += position[i] * multiplierTmp;
    multiplierTmp *= tensorShape[i];
  }
  return index;
}

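/// Combine all elements of \p oldTensorAttr along \p reductionAxis into the
/// single value belonging to the output element at linear \p reductionIndex,
/// using OperationType::calcOneElement as the pairwise combiner.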
template <typename OperationType>
llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
                                  llvm::ArrayRef<int64_t> oldShape,
                                  int64_t reductionAxis,
                                  int64_t reductionIndex) {

  llvm::SmallVector<int64_t> newShape(oldShape);
  newShape[reductionAxis] = 1;
  // Compute the multi-dimensional position corresponding to reductionIndex.
  llvm::SmallVector<int64_t> position =
      getPositionFromIndex(reductionIndex, newShape);
  auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
  // Start from the first position along the reduction axis.
  position[reductionAxis] = 0;
  int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
  llvm::APInt reducedValue = oldTensor[indexAtOldTensor];

  // The stride between consecutive elements along the reduction axis is
  // loop-invariant, so compute it once up front.
  int64_t stride =
      std::accumulate(oldShape.begin() + reductionAxis + 1, oldShape.end(),
                      int64_t{1}, std::multiplies<int64_t>());
  for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
       ++reductionAxisVal) {
    int64_t index = indexAtOldTensor + stride * reductionAxisVal;
    reducedValue =
        OperationType::calcOneElement(reducedValue, oldTensor[index]);
  }
  return reducedValue;
}

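/// Fold a TOSA reduction op (reduce_all, reduce_any, reduce_max, reduce_min,
/// reduce_prod, reduce_sum) applied to a constant integer tensor into a
/// tosa.const holding the reduced values. With aggressiveReduceConstant set,
/// folding also happens when the constant has other users, which may
/// duplicate the tensor data.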
template <typename OperationType>
struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {

  ReduceConstantOptimization(MLIRContext *context,
                             bool aggressiveReduceConstant)
      : OpRewritePattern<OperationType>(context),
        aggressiveReduceConstant(aggressiveReduceConstant) {}

  using OpRewritePattern<OperationType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OperationType op,
                                PatternRewriter &rewriter) const override {
    Value inputOp = op.getInput();
    auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();

    if (!constOp)
      return rewriter.notifyMatchFailure(
          op, "reduce input must be const operation");

    if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
      return rewriter.notifyMatchFailure(
          op, "input operation has more than one user");

    auto resultType = cast<ShapedType>(op.getOutput().getType());

    if (!resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "result type shape is not static");

    auto reductionAxis = op.getAxis();
    const auto denseElementsAttr = constOp.getValue();
    const auto shapedOldElementsValues =
        cast<ShapedType>(denseElementsAttr.getType());

    if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
      return rewriter.notifyMatchFailure(
          op, "reduce input is currently only supported for integer types");

    auto oldShape = shapedOldElementsValues.getShape();
    auto newShape = resultType.getShape();

    auto newNumOfElements =
        std::accumulate(newShape.begin(), newShape.end(), int64_t{1},
                        std::multiplies<int64_t>());
    llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);

    for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
         ++reductionIndex) {

      // Reduce all the elements along the reduction axis.
      newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
          denseElementsAttr, oldShape, reductionAxis, reductionIndex);
    }

    auto rankedTensorType = cast<RankedTensorType>(resultType);
    auto denseAttr =
        mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
    return success();
  }
  const bool aggressiveReduceConstant;
};


} // namespace

void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               bool aggressiveReduceConstant) {
  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceProdOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(
      ctx, aggressiveReduceConstant);
}

void mlir::tosa::populateTosaFoldConstantTransposePatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantTranspose>(ctx);
}

void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantReciprocal>(ctx);
}