//===- TosaFolders.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA operations
//
//===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/SmallVector.h"

#include <functional>
#include <numeric>
24
25using namespace mlir;
26using namespace mlir::tosa;
27
28namespace {
29
30/// Apply the given transformation \p toApply to every element of the tensor to
31/// be transformed \p toTransform.
32///
33/// Elements of \p toTransform are extracted as \p SrcValueType.
34///
35/// \returns A tensor with the same size as \p toTransform, containing
36/// \p TargetValueType values of type \p TargetType.
37template <class SrcValType, class TargetValType, class TargetType>
38DenseElementsAttr applyElementWise(
39 const DenseElementsAttr &toTransform,
40 const std::function<TargetValType(const SrcValType &)> &toApply,
41 TargetType targetType) {
42 SmallVector<TargetValType> transformedValues;
43 // We already know the amount of values we will insert, reserve space for
44 // all of them to avoid dynamic resizing
45 transformedValues.reserve(toTransform.getNumElements());
46 for (auto val : toTransform.getValues<SrcValType>()) {
47 auto transformedVal = toApply(val);
48 transformedValues.push_back(transformedVal);
49 }
50
51 // Make sure that the output tensor has the expected output type
52 auto inShape = toTransform.getType();
53 auto outTy = inShape.cloneWith(shape: {}, elementType: targetType);
54
55 return DenseElementsAttr::get(outTy, transformedValues);
56}
57
58template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
59 const DenseElementsAttr &toTransform,
60 const std::function<APFloat(const APFloat &)> &toApply,
61 FloatType targetType);
62
63/// Function that checks if the type contained in \p toCheck is float.
64LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
65 PatternRewriter &rewriter) {
66 if (isa<FloatType>(Val: toCheck.getType().getElementType())) {
67 return success();
68 }
69 return rewriter.notifyMatchFailure(arg&: location,
70 msg: "Unexpected input tensor type: the "
71 "TOSA spec only allows floats");
72}
73
74/// Function that checks if \p toCheck is a dense TOSA constant tensor.
75LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
76 TosaOp location,
77 PatternRewriter &rewriter) {
78 // Check whether the tensor is constant and dense
79 // TODO We currently ensure the tensor is dense by using the correct type for
80 // the bind_value, however we do not actually need this value. It would be
81 // nicer to only have a check here.
82 DenseElementsAttr tmp;
83 if (!matchPattern(value: toCheck, pattern: m_Constant(bind_value: &tmp))) {
84 return rewriter.notifyMatchFailure(arg&: location,
85 msg: "Non-const or non-dense input tensor");
86 }
87
88 // Make sure it actually is a TOSA constant (the match allows for other
89 // constants as well)
90 if (isa<ConstOp>(Val: toCheck.getDefiningOp())) {
91 return success();
92 }
93
94 return rewriter.notifyMatchFailure(arg&: location,
95 msg: "The reciprocal can only be folded if "
96 "it operates on a TOSA constant");
97}
98
99/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
100LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
101 TosaOp location,
102 PatternRewriter &rewriter) {
103 auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
104 if (failed(Result: floatCheck)) {
105 return floatCheck;
106 }
107 return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
108}
109
110/// Heuristic to decide when to replace a unary operation on a constant with the
111/// folded value.
112/// Folding operations on constants can lead to an increased memory usage
113/// whenever the input cannot be replaced but a new constant is inserted. Hence,
114/// this will currently only suggest folding when the memory impact is
115/// negligible.
116/// Takes the \p unaryOp and the constant input \p values.
117/// \returns Whether folding should be applied.
118bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
119 assert(unaryOp->getNumOperands() == 1);
120 auto inputOp = unaryOp->getOperand(idx: 0);
121
122 // If the input is a splat, we don't care for the number of users
123 if (isa<SplatElementsAttr>(Val: values)) {
124 return true;
125 }
126
127 // If this is the only use of the tensor it should be replaced as no
128 // additional memory is required
129 return inputOp.hasOneUse();
130}
131
132template <typename RangeType>
133DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
134 ShapedType outputType,
135 llvm::ArrayRef<int64_t> permValues) {
136 using ElementType = std::decay_t<decltype(*std::begin(data))>;
137
138 assert(inputType.getElementType() == outputType.getElementType());
139
140 if (inputType.getNumElements() == 0)
141 return DenseElementsAttr::get(outputType, llvm::ArrayRef<ElementType>{});
142
143 auto inputShape = inputType.getShape();
144
145 // The inverted permutation map and strides of the output are used to compute
146 // the contribution of a given dimension to the destination linear index in
147 // an order-independent way.
148 auto outputStrides = computeStrides(sizes: outputType.getShape());
149 auto invertedPermValues = invertPermutationVector(permutation: permValues);
150
151 auto initialValue = *std::begin(data);
152 SmallVector<ElementType> outputValues(inputType.getNumElements(),
153 initialValue);
154
155 for (const auto &it : llvm::enumerate(data)) {
156 auto srcLinearIndex = it.index();
157
158 uint64_t dstLinearIndex = 0;
159 for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
160 // Compute the index into the current dimension of the source vector.
161 auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
162 srcLinearIndex /= inputShape[dim];
163
164 // Add the contribution of the current dimension to the output using the
165 // permutation map.
166 dstLinearIndex +=
167 outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
168 }
169
170 outputValues[dstLinearIndex] = it.value();
171 }
172
173 return DenseElementsAttr::get(outputType,
174 llvm::ArrayRef<ElementType>(outputValues));
175}
176
177// Try to get the values of a DenseResourceElementsAttr construct
178template <typename T>
179std::optional<ArrayRef<T>> tryGetDenseResourceValues(ElementsAttr attr) {
180 if (auto denseResource = dyn_cast<DenseResourceElementsAttr>(Val&: attr)) {
181 // Check that the resource memory blob exists
182 AsmResourceBlob *blob = denseResource.getRawHandle().getBlob();
183 if (!blob)
184 return std::nullopt;
185
186 // Check that the data are in a valid form
187 bool isSplat = false;
188 if (!DenseElementsAttr::isValidRawBuffer(type: attr.getShapedType(),
189 rawBuffer: blob->getData(), detectedSplat&: isSplat)) {
190 return std::nullopt;
191 }
192
193 return blob->template getDataAs<T>();
194 }
195
196 return std::nullopt;
197}
198
199// A type specialized transposition of an ElementsAttr.
200// This implementation tries to operate on the underlying data in its raw
201// representation when possible to avoid allocating a large number of Attribute
202// objects.
203DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
204 ShapedType outputType,
205 llvm::ArrayRef<int64_t> permValues) {
206 // Handle generic ElementsAttr
207 if (auto data = attr.tryGetValues<bool>())
208 return transposeType(data: *data, inputType, outputType, permValues);
209
210 if (auto data = attr.tryGetValues<int8_t>())
211 return transposeType(data: *data, inputType, outputType, permValues);
212
213 if (auto data = attr.tryGetValues<int16_t>())
214 return transposeType(data: *data, inputType, outputType, permValues);
215
216 if (auto data = attr.tryGetValues<int32_t>())
217 return transposeType(data: *data, inputType, outputType, permValues);
218
219 if (auto data = attr.tryGetValues<int64_t>())
220 return transposeType(data: *data, inputType, outputType, permValues);
221
222 if (auto data = attr.tryGetValues<float>())
223 return transposeType(data: *data, inputType, outputType, permValues);
224
225 if (auto data = attr.tryGetValues<APFloat>())
226 return transposeType(data: *data, inputType, outputType, permValues);
227
228 // Handle DenseResourceElementsAttr
229 if (isa<DenseResourceElementsAttr>(Val: attr)) {
230 auto elementTy = attr.getElementType();
231
232 if (auto data = tryGetDenseResourceValues<bool>(attr);
233 data && elementTy.isInteger(width: 1))
234 return transposeType(data: *data, inputType, outputType, permValues);
235
236 if (auto data = tryGetDenseResourceValues<int8_t>(attr);
237 data && elementTy.isInteger(width: 8))
238 return transposeType(data: *data, inputType, outputType, permValues);
239
240 if (auto data = tryGetDenseResourceValues<int16_t>(attr);
241 data && elementTy.isInteger(width: 16))
242 return transposeType(data: *data, inputType, outputType, permValues);
243
244 if (auto data = tryGetDenseResourceValues<int32_t>(attr);
245 data && elementTy.isInteger(width: 32))
246 return transposeType(data: *data, inputType, outputType, permValues);
247
248 if (auto data = tryGetDenseResourceValues<int64_t>(attr);
249 data && elementTy.isInteger(width: 64))
250 return transposeType(data: *data, inputType, outputType, permValues);
251
252 if (auto data = tryGetDenseResourceValues<float>(attr);
253 data && elementTy.isF32())
254 return transposeType(data: *data, inputType, outputType, permValues);
255 }
256
257 return nullptr;
258}
259
260struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
261 using OpRewritePattern::OpRewritePattern;
262
263 LogicalResult matchAndRewrite(tosa::TransposeOp op,
264 PatternRewriter &rewriter) const override {
265 auto outputType = cast<ShapedType>(Val: op.getType());
266 // TOSA supports quantized types.
267 if (!outputType.getElementType().isIntOrIndexOrFloat())
268 return failure();
269
270 ElementsAttr inputValues;
271 if (!matchPattern(value: op.getInput1(), pattern: m_Constant(bind_value: &inputValues)))
272 return failure();
273 // Make sure the input is a constant that has a single user.
274 if (!llvm::hasSingleElement(C: op.getInput1().getDefiningOp()->getUsers()))
275 return failure();
276
277 auto permValues = llvm::map_to_vector(
278 C: op.getPerms(), F: [](const int32_t v) { return static_cast<int64_t>(v); });
279
280 auto inputType = cast<ShapedType>(Val: op.getInput1().getType());
281
282 auto resultAttr = transpose(attr: inputValues, inputType, outputType, permValues);
283 if (!resultAttr) {
284 return rewriter.notifyMatchFailure(
285 arg&: op, msg: "unsupported attribute or element type");
286 }
287
288 rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, args&: outputType, args&: resultAttr);
289 return success();
290 }
291};
292
293struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
294
295 using OpRewritePattern::OpRewritePattern;
296
297 LogicalResult matchAndRewrite(ReciprocalOp recip,
298 PatternRewriter &rewriter) const override {
299 auto inputTensor = recip.getInput1();
300
301 // Check that we can apply folding
302 auto preCondCheck =
303 notifyIfNotConstantFloatTosaTensor(toCheck: inputTensor, location: recip, rewriter);
304 if (failed(Result: preCondCheck)) {
305 return preCondCheck;
306 }
307
308 // Extract the tensor values
309 DenseElementsAttr inputValues;
310 matchPattern(value: inputTensor, pattern: m_Constant(bind_value: &inputValues));
311
312 // Check whether this should be folded.
313 if (!constantUnaryOpShouldBeFolded(unaryOp: recip, values: inputValues)) {
314 return rewriter.notifyMatchFailure(
315 arg&: recip, msg: "Currently, reciprocals will only be folded if the input "
316 "tensor has a single user");
317 }
318
319 // Create a new tensor with the updated values
320 auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
321 toTransform: inputValues, toApply: &ReciprocalOp::calcOneElement,
322 targetType: cast<FloatType>(Val: inputValues.getElementType()));
323
324 // Replace the use of the reciprocal with the transformed tensor
325 rewriter.replaceOpWithNewOp<ConstOp>(op: recip, args: newTensor.getType(), args&: newTensor);
326 return success();
327 }
328};
329
330/// Getting the axes position of the element which is located
331/// in the tensor at the counter index
332
333llvm::SmallVector<int64_t>
334getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
335 int64_t remaining = index;
336 llvm::SmallVector<int64_t> position(tensorShape.size(), 0);
337 for (int64_t i = tensorShape.size() - 1; i >= 0; --i) {
338 position[i] = remaining % tensorShape[i];
339 remaining /= tensorShape[i];
340 }
341 return position;
342}
343
344/// Getting the index of the element which is located at the
345/// axes position in the tensor
346
347int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
348 llvm::ArrayRef<int64_t> tensorShape) {
349 int64_t index = 0;
350 int64_t multiplierTmp = 1;
351 for (int64_t i = position.size() - 1; i >= 0; --i) {
352 index += position[i] * multiplierTmp;
353 multiplierTmp *= tensorShape[i];
354 }
355 return index;
356}
357
358template <typename OperationType>
359llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
360 llvm::ArrayRef<int64_t> oldShape,
361 int64_t reductionAxis,
362 int64_t reductionIndex) {
363
364 llvm::SmallVector<int64_t> newShape(oldShape);
365 newShape[reductionAxis] = 1;
366 /// Let's calculate the position of the index
367 llvm::SmallVector<int64_t> position =
368 getPositionFromIndex(index: reductionIndex, tensorShape: newShape);
369 auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
370 /// Starting from the first positon along the reduction axis
371 position[reductionAxis] = 0;
372 int64_t indexAtOldTensor = getIndexFromPosition(position, tensorShape: oldShape);
373 llvm::APInt reducedValue = oldTensor[indexAtOldTensor];
374
375 for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
376 ++reductionAxisVal) {
377
378 int64_t stride = std::accumulate(first: oldShape.begin() + reductionAxis + 1,
379 last: oldShape.end(), init: 1, binary_op: std::multiplies<int>());
380 int64_t index = indexAtOldTensor + stride * reductionAxisVal;
381 reducedValue =
382 OperationType::calcOneElement(reducedValue, oldTensor[index]);
383 }
384 return reducedValue;
385}
386
387template <typename OperationType>
388struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
389
390 ReduceConstantOptimization(MLIRContext *context,
391 bool aggressiveReduceConstant)
392 : OpRewritePattern<OperationType>(context),
393 aggressiveReduceConstant(aggressiveReduceConstant) {}
394
395 using OpRewritePattern<OperationType>::OpRewritePattern;
396
397 LogicalResult matchAndRewrite(OperationType op,
398 PatternRewriter &rewriter) const override {
399 Value inputOp = op.getInput();
400 auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();
401
402 if (!constOp)
403 return rewriter.notifyMatchFailure(
404 op, "reduce input must be const operation");
405
406 if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
407 return rewriter.notifyMatchFailure(
408 op, "input operation has more than one user");
409
410 auto resultType = cast<ShapedType>(op.getOutput().getType());
411
412 if (!resultType.hasStaticShape())
413 return rewriter.notifyMatchFailure(op, "result type shape is not static");
414
415 auto reductionAxis = op.getAxis();
416 const auto denseElementsAttr = constOp.getValues();
417 const auto shapedOldElementsValues =
418 cast<ShapedType>(Val: denseElementsAttr.getType());
419
420 if (!llvm::isa<IntegerType>(Val: shapedOldElementsValues.getElementType()))
421 return rewriter.notifyMatchFailure(
422 op, "reduce input currently supported with integer type");
423
424 auto oldShape = shapedOldElementsValues.getShape();
425 auto newShape = resultType.getShape();
426
427 auto newNumOfElements = std::accumulate(newShape.begin(), newShape.end(), 1,
428 std::multiplies<int>());
429 llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);
430
431 for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
432 ++reductionIndex) {
433
434 /// Let's reduce all the elements along this reduction axis
435 newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
436 denseElementsAttr, oldShape, reductionAxis, reductionIndex);
437 }
438
439 auto rankedTensorType = cast<RankedTensorType>(resultType);
440 auto denseAttr =
441 mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
442 rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
443 return success();
444 }
445 const bool aggressiveReduceConstant;
446};
447
448} // namespace
449
450void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
451 RewritePatternSet &patterns,
452 bool aggressiveReduceConstant) {
453 patterns.add<ReduceConstantOptimization<ReduceAllOp>>(
454 arg&: ctx, args&: aggressiveReduceConstant);
455 patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(
456 arg&: ctx, args&: aggressiveReduceConstant);
457 patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(
458 arg&: ctx, args&: aggressiveReduceConstant);
459 patterns.add<ReduceConstantOptimization<ReduceMinOp>>(
460 arg&: ctx, args&: aggressiveReduceConstant);
461 patterns.add<ReduceConstantOptimization<ReduceProductOp>>(
462 arg&: ctx, args&: aggressiveReduceConstant);
463 patterns.add<ReduceConstantOptimization<ReduceSumOp>>(
464 arg&: ctx, args&: aggressiveReduceConstant);
465}
466
467void mlir::tosa::populateTosaFoldConstantTransposePatterns(
468 MLIRContext *ctx, RewritePatternSet &patterns) {
469 patterns.add<TosaFoldConstantTranspose>(arg&: ctx);
470}
471
472void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
473 MLIRContext *ctx, RewritePatternSet &patterns) {
474 patterns.add<TosaFoldConstantReciprocal>(arg&: ctx);
475}
476

source code of mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp