//===- TosaFolders.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA operations
//
//===----------------------------------------------------------------------===//

#include <functional>
#include <numeric>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {
/// Apply the given transformation \p toApply to every element of the tensor to
/// be transformed \p toTransform.
///
/// Elements of \p toTransform are extracted as \p SrcValType.
///
/// \returns A tensor with the same size as \p toTransform, containing
/// \p TargetValType values of type \p targetType.
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
    const DenseElementsAttr &toTransform,
    const std::function<TargetValType(const SrcValType &)> &toApply,
    TargetType targetType) {
  SmallVector<TargetValType> transformedValues;
  // We already know the number of values we will insert; reserve space for
  // all of them to avoid dynamic resizing.
  transformedValues.reserve(toTransform.getNumElements());
  for (auto val : toTransform.getValues<SrcValType>()) {
    auto transformedVal = toApply(val);
    transformedValues.push_back(transformedVal);
  }

  // Make sure that the output tensor has the expected element type.
  auto inShape = toTransform.getType();
  auto outTy = inShape.cloneWith({}, targetType);

  return DenseElementsAttr::get(outTy, transformedValues);
}

template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
    const DenseElementsAttr &toTransform,
    const std::function<APFloat(const APFloat &)> &toApply,
    FloatType targetType);

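// Illustrative usage sketch (not part of the upstream file): applyElementWise
// can fold any unary float computation over a dense constant. For example,
// negating every element of a hypothetical DenseElementsAttr `values`:
//
//   auto negated = applyElementWise<APFloat, APFloat, FloatType>(
//       values, [](const APFloat &v) { return llvm::neg(v); },
//       cast<FloatType>(values.getElementType()));
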
/// Function that checks if the type contained in \p toCheck is float.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
                               PatternRewriter &rewriter) {
  if (isa<FloatType>(toCheck.getType().getElementType())) {
    return success();
  }
  return rewriter.notifyMatchFailure(location,
                                     "Unexpected input tensor type: the "
                                     "TOSA spec only allows floats");
}

/// Function that checks if \p toCheck is a dense TOSA constant tensor.
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
                                                TosaOp location,
                                                PatternRewriter &rewriter) {
  // Check whether the tensor is constant and dense.
  // TODO: We currently ensure the tensor is dense by using the correct type
  // for the bind value; however, we do not actually need this value. It would
  // be nicer to only have a check here.
  DenseElementsAttr tmp;
  if (!matchPattern(toCheck, m_Constant(&tmp))) {
    return rewriter.notifyMatchFailure(location,
                                       "Non-const or non-dense input tensor");
  }

  // Make sure it actually is a TOSA constant (the match allows for other
  // constants as well).
  if (isa<ConstOp>(toCheck.getDefiningOp())) {
    return success();
  }

  return rewriter.notifyMatchFailure(location,
                                     "The reciprocal can only be folded if "
                                     "it operates on a TOSA constant");
}

/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
                                                 TosaOp location,
                                                 PatternRewriter &rewriter) {
  auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
  if (failed(floatCheck)) {
    return floatCheck;
  }
  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
}

/// Heuristic to decide when to replace a unary operation on a constant with
/// the folded value.
/// Folding operations on constants can increase memory usage whenever the
/// input cannot be replaced but a new constant is inserted. Hence, this
/// currently only suggests folding when the memory impact is negligible.
/// Takes the \p unaryOp and the constant input \p values.
/// \returns Whether folding should be applied.
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
  assert(unaryOp->getNumOperands() == 1);
  auto inputOp = unaryOp->getOperand(0);

  // If the input is a splat, we don't care about the number of users.
  if (isa<SplatElementsAttr>(values)) {
    return true;
  }

  // If this is the only use of the tensor, it should be replaced, as no
  // additional memory is required.
  return inputOp.hasOneUse();
}

template <typename RangeType>
DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
                                ShapedType outputType,
                                llvm::ArrayRef<int64_t> permValues) {
  using ElementType = std::decay_t<decltype(*std::begin(data))>;

  assert(inputType.getElementType() == outputType.getElementType());

  if (inputType.getNumElements() == 0)
    return DenseElementsAttr::get(outputType, llvm::ArrayRef<ElementType>{});

  auto inputShape = inputType.getShape();

  // The inverted permutation map and strides of the output are used to compute
  // the contribution of a given dimension to the destination linear index in
  // an order-independent way.
  auto outputStrides = computeStrides(outputType.getShape());
  auto invertedPermValues = invertPermutationVector(permValues);

  auto initialValue = *std::begin(data);
  SmallVector<ElementType> outputValues(inputType.getNumElements(),
                                        initialValue);

  for (const auto &it : llvm::enumerate(data)) {
    auto srcLinearIndex = it.index();

    uint64_t dstLinearIndex = 0;
    for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
      // Compute the index into the current dimension of the source vector.
      auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
      srcLinearIndex /= inputShape[dim];

      // Add the contribution of the current dimension to the output using the
      // permutation map.
      dstLinearIndex +=
          outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
    }

    outputValues[dstLinearIndex] = it.value();
  }

  return DenseElementsAttr::get(outputType,
                                llvm::ArrayRef<ElementType>(outputValues));
}

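// Worked example (illustrative, not part of the upstream file): transposing a
// 2x3 input with permValues = [1, 0] yields a 3x2 output, so
// outputStrides = [2, 1] and invertedPermValues = [1, 0]. The source element
// at linear index 1 sits at position (0, 1); the inner loop accumulates
//   outputStrides[invertedPermValues[1]] * 1
//     + outputStrides[invertedPermValues[0]] * 0 = 2 * 1 + 1 * 0 = 2,
// which is the linear index of position (1, 0) in the 3x2 output.
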
// Try to get the underlying values of a DenseResourceElementsAttr.
template <typename T>
std::optional<ArrayRef<T>> tryGetDenseResourceValues(ElementsAttr attr) {
  if (auto denseResource = dyn_cast<DenseResourceElementsAttr>(attr)) {
    // Check that the resource memory blob exists.
    AsmResourceBlob *blob = denseResource.getRawHandle().getBlob();
    if (!blob)
      return std::nullopt;

    // Check that the data are in a valid form.
    bool isSplat = false;
    if (!DenseElementsAttr::isValidRawBuffer(attr.getShapedType(),
                                             blob->getData(), isSplat)) {
      return std::nullopt;
    }

    return blob->template getDataAs<T>();
  }

  return std::nullopt;
}

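// Illustrative sketch (not part of the upstream file): reading a resource
// blob of 32-bit floats, assuming `attr` is a hypothetical ElementsAttr that
// wraps a DenseResourceElementsAttr:
//
//   if (auto data = tryGetDenseResourceValues<float>(attr)) {
//     for (float v : *data)
//       llvm::outs() << v << "\n";
//   }
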
// A type specialized transposition of an ElementsAttr.
// This implementation tries to operate on the underlying data in its raw
// representation when possible to avoid allocating a large number of Attribute
// objects.
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                            ShapedType outputType,
                            llvm::ArrayRef<int64_t> permValues) {
  // Handle generic ElementsAttr
  if (auto data = attr.tryGetValues<bool>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int8_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int16_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int32_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int64_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<float>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<APFloat>())
    return transposeType(*data, inputType, outputType, permValues);

  // Handle DenseResourceElementsAttr
  if (isa<DenseResourceElementsAttr>(attr)) {
    auto elementTy = attr.getElementType();

    if (auto data = tryGetDenseResourceValues<bool>(attr);
        data && elementTy.isInteger(1))
      return transposeType(*data, inputType, outputType, permValues);

    if (auto data = tryGetDenseResourceValues<int8_t>(attr);
        data && elementTy.isInteger(8))
      return transposeType(*data, inputType, outputType, permValues);

    if (auto data = tryGetDenseResourceValues<int16_t>(attr);
        data && elementTy.isInteger(16))
      return transposeType(*data, inputType, outputType, permValues);

    if (auto data = tryGetDenseResourceValues<int32_t>(attr);
        data && elementTy.isInteger(32))
      return transposeType(*data, inputType, outputType, permValues);

    if (auto data = tryGetDenseResourceValues<int64_t>(attr);
        data && elementTy.isInteger(64))
      return transposeType(*data, inputType, outputType, permValues);

    if (auto data = tryGetDenseResourceValues<float>(attr);
        data && elementTy.isF32())
      return transposeType(*data, inputType, outputType, permValues);
  }

  return nullptr;
}

struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto outputType = cast<ShapedType>(op.getType());
    // TOSA supports quantized types.
    if (!outputType.getElementType().isIntOrIndexOrFloat())
      return failure();

    ElementsAttr inputValues;
    if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
      return failure();
    // Make sure the input is a constant that has a single user.
    if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
      return failure();

    auto permValues = llvm::map_to_vector(
        op.getPerms(), [](const int32_t v) { return static_cast<int64_t>(v); });

    auto inputType = cast<ShapedType>(op.getInput1().getType());

    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
    if (!resultAttr) {
      return rewriter.notifyMatchFailure(
          op, "unsupported attribute or element type");
    }

    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
    return success();
  }
};

struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReciprocalOp recip,
                                PatternRewriter &rewriter) const override {
    auto inputTensor = recip.getInput1();

    // Check that we can apply folding.
    auto preCondCheck =
        notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
    if (failed(preCondCheck)) {
      return preCondCheck;
    }

    // Extract the tensor values.
    DenseElementsAttr inputValues;
    matchPattern(inputTensor, m_Constant(&inputValues));

    // Check whether this should be folded.
    if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
      return rewriter.notifyMatchFailure(
          recip, "Currently, reciprocals will only be folded if the input "
                 "tensor has a single user");
    }

    // Create a new tensor with the updated values.
    auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
        inputValues, &ReciprocalOp::calcOneElement,
        cast<FloatType>(inputValues.getElementType()));

    // Replace the use of the reciprocal with the transformed tensor.
    rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
    return success();
  }
};

/// Compute the axis positions (multi-dimensional index) of the element
/// located at the given linear \p index in a tensor of shape \p tensorShape.
llvm::SmallVector<int64_t>
getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
  int64_t remaining = index;
  llvm::SmallVector<int64_t> position(tensorShape.size(), 0);
  for (int64_t i = tensorShape.size() - 1; i >= 0; --i) {
    position[i] = remaining % tensorShape[i];
    remaining /= tensorShape[i];
  }
  return position;
}

/// Compute the linear index of the element located at the given axis
/// positions \p position in a tensor of shape \p tensorShape.
int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
                             llvm::ArrayRef<int64_t> tensorShape) {
  int64_t index = 0;
  int64_t multiplierTmp = 1;
  for (int64_t i = position.size() - 1; i >= 0; --i) {
    index += position[i] * multiplierTmp;
    multiplierTmp *= tensorShape[i];
  }
  return index;
}

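// Worked example (illustrative, not part of the upstream file): for a tensor
// of shape [2, 3, 4], the element at linear index 17 decomposes back-to-front
// into positions 17 % 4 = 1, (17 / 4) % 3 = 1, and (17 / 12) % 2 = 1, so
// getPositionFromIndex(17, {2, 3, 4}) returns {1, 1, 1}, and the inverse
// getIndexFromPosition({1, 1, 1}, {2, 3, 4}) == 1 * 12 + 1 * 4 + 1 == 17.
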
template <typename OperationType>
llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
                                  llvm::ArrayRef<int64_t> oldShape,
                                  int64_t reductionAxis,
                                  int64_t reductionIndex) {

  llvm::SmallVector<int64_t> newShape(oldShape);
  newShape[reductionAxis] = 1;
  // Calculate the position corresponding to the reduction index.
  llvm::SmallVector<int64_t> position =
      getPositionFromIndex(reductionIndex, newShape);
  auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
  // Start from the first position along the reduction axis.
  position[reductionAxis] = 0;
  int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
  llvm::APInt reducedValue = oldTensor[indexAtOldTensor];

  // The stride between consecutive elements along the reduction axis.
  int64_t stride = std::accumulate(oldShape.begin() + reductionAxis + 1,
                                   oldShape.end(), 1, std::multiplies<int>());
  for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
       ++reductionAxisVal) {
    int64_t index = indexAtOldTensor + stride * reductionAxisVal;
    reducedValue =
        OperationType::calcOneElement(reducedValue, oldTensor[index]);
  }
  return reducedValue;
}

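// Worked example (illustrative, not part of the upstream file): reducing a
// 2x3 tensor along axis 1 gives newShape = [2, 1] and two output elements.
// For reductionIndex = 1, the position in newShape is (1, 0), so
// indexAtOldTensor = 1 * 3 + 0 = 3 and stride = 1; the loop then combines the
// elements at old linear indices 3, 4, and 5 (the second row) through
// OperationType::calcOneElement.
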
template <typename OperationType>
struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {

  ReduceConstantOptimization(MLIRContext *context,
                             bool aggressiveReduceConstant)
      : OpRewritePattern<OperationType>(context),
        aggressiveReduceConstant(aggressiveReduceConstant) {}

  using OpRewritePattern<OperationType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OperationType op,
                                PatternRewriter &rewriter) const override {
    Value inputOp = op.getInput();
    auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();

    if (!constOp)
      return rewriter.notifyMatchFailure(
          op, "reduce input must be const operation");

    if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
      return rewriter.notifyMatchFailure(
          op, "input operation has more than one user");

    auto resultType = cast<ShapedType>(op.getOutput().getType());

    if (!resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "result type shape is not static");

    auto reductionAxis = op.getAxis();
    const auto denseElementsAttr = constOp.getValues();
    const auto shapedOldElementsValues =
        cast<ShapedType>(denseElementsAttr.getType());

    if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
      return rewriter.notifyMatchFailure(
          op, "reduce input is currently only supported for integer types");

    auto oldShape = shapedOldElementsValues.getShape();
    auto newShape = resultType.getShape();

    auto newNumOfElements = std::accumulate(newShape.begin(), newShape.end(), 1,
                                            std::multiplies<int>());
    llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);

    for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
         ++reductionIndex) {
      // Reduce all the elements along this reduction axis.
      newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
          denseElementsAttr, oldShape, reductionAxis, reductionIndex);
    }

    auto rankedTensorType = cast<RankedTensorType>(resultType);
    auto denseAttr =
        mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
    return success();
  }
  const bool aggressiveReduceConstant;
};

} // namespace

void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               bool aggressiveReduceConstant) {
  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceProductOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(
      ctx, aggressiveReduceConstant);
}

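// Illustrative usage sketch (not part of the upstream file): a pass would
// typically register these patterns and hand them to the greedy rewrite
// driver. `MyPass` and the driver call below are assumptions for
// illustration, not part of this file:
//
//   void MyPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     mlir::tosa::populateTosaConstantReduction(
//         &getContext(), patterns, /*aggressiveReduceConstant=*/false);
//     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
//   }
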
void mlir::tosa::populateTosaFoldConstantTransposePatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantTranspose>(ctx);
}

void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantReciprocal>(ctx);
}
