//===- TosaFolders.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA operations
//
//===----------------------------------------------------------------------===//

#include <functional>
#include <numeric>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Apply the given transformation \p toApply to every element of the tensor to
/// be transformed \p toTransform.
///
/// Elements of \p toTransform are extracted as \p SrcValType.
///
/// \returns A tensor with the same size as \p toTransform, containing
/// \p TargetValType values of type \p targetType.
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
    const DenseElementsAttr &toTransform,
    const std::function<TargetValType(const SrcValType &)> &toApply,
    TargetType targetType) {
  SmallVector<TargetValType> transformedValues;
  // We already know the number of values we will insert; reserve space for
  // all of them to avoid dynamic resizing.
  transformedValues.reserve(toTransform.getNumElements());
  for (auto val : toTransform.getValues<SrcValType>()) {
    auto transformedVal = toApply(val);
    transformedValues.push_back(transformedVal);
  }

  // Make sure that the output tensor has the expected output type.
  auto inType = toTransform.getType();
  auto outTy = inType.cloneWith({}, targetType);

  return DenseElementsAttr::get(outTy, transformedValues);
}

template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
    const DenseElementsAttr &toTransform,
    const std::function<APFloat(const APFloat &)> &toApply,
    FloatType targetType);
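
// Illustrative usage sketch (not part of the folders below): negating every
// element of a constant float attribute. `someAttr` and `f32Ty` are
// placeholder names for an existing DenseElementsAttr and FloatType.
//   auto negated = applyElementWise<APFloat, APFloat, FloatType>(
//       someAttr, [](const APFloat &v) { return llvm::neg(v); }, f32Ty);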

/// Function that checks if the type contained in \p toCheck is float.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
                               PatternRewriter &rewriter) {
  if (isa<FloatType>(toCheck.getType().getElementType())) {
    return success();
  }
  return rewriter.notifyMatchFailure(location,
                                     "Unexpected input tensor type: the "
                                     "TOSA spec only allows floats");
}

/// Function that checks if \p toCheck is a dense TOSA constant tensor.
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
                                                TosaOp location,
                                                PatternRewriter &rewriter) {
  // Check whether the tensor is constant and dense.
  // TODO: We currently ensure the tensor is dense by using the correct type
  // for the bind_value; however, we do not actually need this value. It would
  // be nicer to only have a check here.
  DenseElementsAttr tmp;
  if (!matchPattern(toCheck, m_Constant(&tmp))) {
    return rewriter.notifyMatchFailure(location,
                                       "Non-const or non-dense input tensor");
  }

  // Make sure it actually is a TOSA constant (the match allows for other
  // constants as well).
  if (isa<ConstOp>(toCheck.getDefiningOp())) {
    return success();
  }

  return rewriter.notifyMatchFailure(location,
                                     "The reciprocal can only be folded if "
                                     "it operates on a TOSA constant");
}

/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
                                                 TosaOp location,
                                                 PatternRewriter &rewriter) {
  auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
  if (failed(floatCheck)) {
    return floatCheck;
  }
  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
}

/// Heuristic to decide when to replace a unary operation on a constant with
/// the folded value.
/// Folding operations on constants can lead to increased memory usage whenever
/// the input cannot be replaced but a new constant is inserted. Hence, this
/// will currently only suggest folding when the memory impact is negligible.
/// Takes the \p unaryOp and the constant input \p values.
/// \returns Whether folding should be applied.
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
  assert(unaryOp->getNumOperands() == 1);
  auto inputOp = unaryOp->getOperand(0);

  // If the input is a splat, we don't care about the number of users.
  if (isa<SplatElementsAttr>(values)) {
    return true;
  }

  // If this is the only use of the tensor, it should be replaced, as no
  // additional memory is required.
  return inputOp.hasOneUse();
}
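
// For example (illustrative): folding the reciprocal of a splat such as
//   dense<4.0> : tensor<1024xf32>
// only materializes a single new scalar, so it is always worthwhile, whereas
// folding a non-splat input that still has other users would duplicate the
// entire tensor.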

template <typename RangeType>
DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
                                ShapedType outputType,
                                llvm::ArrayRef<int64_t> permValues) {
  using ElementType = std::decay_t<decltype(*std::begin(data))>;

  assert(inputType.getElementType() == outputType.getElementType());

  if (inputType.getNumElements() == 0)
    return DenseElementsAttr::get(outputType, llvm::ArrayRef<ElementType>{});

  auto inputShape = inputType.getShape();

  // The inverted permutation map and strides of the output are used to compute
  // the contribution of a given dimension to the destination linear index in
  // an order-independent way.
  auto outputStrides = computeStrides(outputType.getShape());
  auto invertedPermValues = invertPermutationVector(permValues);

  auto initialValue = *std::begin(data);
  SmallVector<ElementType> outputValues(inputType.getNumElements(),
                                        initialValue);

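  // Worked example (illustrative): for inputShape = [2, 3] and permValues =
  // [1, 0], the output shape is [3, 2], outputStrides = [2, 1], and
  // invertedPermValues = [1, 0]. The element at source linear index 1, i.e.
  // source position [0, 1], lands at destination linear index
  // outputStrides[invertedPermValues[1]] * 1 = 2, i.e. output position [1, 0].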
  for (const auto &it : llvm::enumerate(data)) {
    auto srcLinearIndex = it.index();

    uint64_t dstLinearIndex = 0;
    for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
      // Compute the index into the current dimension of the source vector.
      auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
      srcLinearIndex /= inputShape[dim];

      // Add the contribution of the current dimension to the output using the
      // permutation map.
      dstLinearIndex +=
          outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
    }

    outputValues[dstLinearIndex] = it.value();
  }

  return DenseElementsAttr::get(outputType,
                                llvm::ArrayRef<ElementType>(outputValues));
}

// A type-specialized transposition of an ElementsAttr.
// This implementation tries to operate on the underlying data in its raw
// representation when possible to avoid allocating a large number of Attribute
// objects.
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                            ShapedType outputType,
                            llvm::ArrayRef<int64_t> permValues) {
  if (auto data = attr.tryGetValues<bool>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int8_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int16_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int32_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int64_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<float>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<APFloat>())
    return transposeType(*data, inputType, outputType, permValues);

  return nullptr;
}

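// Sketch of the rewrite performed by the pattern below (the IR shown is
// illustrative):
//   %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>}
//   %in = "tosa.const"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>}
//   %out = tosa.transpose %in, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
// becomes
//   %out = "tosa.const"() {value = dense<[[1, 4], [2, 5], [3, 6]]> : tensor<3x2xi32>}
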
struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto outputType = cast<ShapedType>(op.getType());
    // TOSA supports quantized types, which are not handled here.
    if (!outputType.getElementType().isIntOrIndexOrFloat())
      return failure();

    ElementsAttr inputValues;
    if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
      return failure();
    // Make sure the input is a constant that has a single user.
    if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
      return failure();

    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return failure();
    auto permValues = llvm::map_to_vector(
        // TOSA allows both 32- and 64-bit integer tensors here.
        permAttr.getValues<APInt>(),
        [](const APInt &val) { return val.getSExtValue(); });

    auto inputType = cast<ShapedType>(op.getInput1().getType());

    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
    if (!resultAttr) {
      return rewriter.notifyMatchFailure(
          op, "unsupported attribute or element type");
    }

    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
    return success();
  }
};

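// Sketch of the rewrite performed by the pattern below (the IR shown is
// illustrative):
//   %in = "tosa.const"() {value = dense<2.0> : tensor<4xf32>}
//   %out = tosa.reciprocal %in : (tensor<4xf32>) -> tensor<4xf32>
// becomes
//   %out = "tosa.const"() {value = dense<5.000000e-01> : tensor<4xf32>}
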
struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {

  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReciprocalOp recip,
                                PatternRewriter &rewriter) const override {
    auto inputTensor = recip.getInput1();

    // Check that we can apply folding.
    auto preCondCheck =
        notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
    if (failed(preCondCheck)) {
      return preCondCheck;
    }

    // Extract the tensor values.
    DenseElementsAttr inputValues;
    matchPattern(inputTensor, m_Constant(&inputValues));

    // Check whether this should be folded.
    if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
      return rewriter.notifyMatchFailure(
          recip, "Currently, reciprocals will only be folded if the input "
                 "tensor has a single user");
    }

    // Create a new tensor with the updated values.
    auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
        inputValues, &ReciprocalOp::calcOneElement,
        cast<FloatType>(inputValues.getElementType()));

    // Replace the use of the reciprocal with the transformed tensor.
    rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
    return success();
  }
};

/// Compute the multi-dimensional position (one index per axis) of the element
/// located at the given linear \p index in a tensor of shape \p tensorShape.
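/// For example, with tensorShape = [2, 3], index 5 maps to position [1, 2].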
llvm::SmallVector<int64_t>
getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
  int64_t remaining = index;
  llvm::SmallVector<int64_t> position(tensorShape.size(), 0);
  for (int64_t i = tensorShape.size() - 1; i >= 0; --i) {
    position[i] = remaining % tensorShape[i];
    remaining /= tensorShape[i];
  }
  return position;
}

/// Compute the linear index of the element located at the multi-dimensional
/// \p position in a tensor of shape \p tensorShape.
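/// For example, with tensorShape = [2, 3], position [1, 2] maps to index 5.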
int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
                             llvm::ArrayRef<int64_t> tensorShape) {
  int64_t index = 0;
  int64_t multiplierTmp = 1;
  for (int64_t i = position.size() - 1; i >= 0; --i) {
    index += position[i] * multiplierTmp;
    multiplierTmp *= tensorShape[i];
  }
  return index;
}

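/// Combine all elements of \p oldTensorAttr along \p reductionAxis into the
/// single output element identified by \p reductionIndex, using
/// OperationType::calcOneElement as the combining function.
/// Worked example (illustrative): for oldShape = [2, 3], reductionAxis = 0,
/// and reductionIndex = 2, the output shape is [1, 3] and this combines the
/// old elements at positions [0, 2] and [1, 2].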
template <typename OperationType>
llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
                                  llvm::ArrayRef<int64_t> oldShape,
                                  int64_t reductionAxis,
                                  int64_t reductionIndex) {

  llvm::SmallVector<int64_t> newShape(oldShape);
  newShape[reductionAxis] = 1;
  // Calculate the position corresponding to the reduction index.
  llvm::SmallVector<int64_t> position =
      getPositionFromIndex(reductionIndex, newShape);
  auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
  // Start from the first position along the reduction axis.
  position[reductionAxis] = 0;
  int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
  llvm::APInt reducedValue = oldTensor[indexAtOldTensor];

  for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
       ++reductionAxisVal) {

    int64_t stride = std::accumulate(oldShape.begin() + reductionAxis + 1,
                                     oldShape.end(), int64_t(1),
                                     std::multiplies<int64_t>());
    int64_t index = indexAtOldTensor + stride * reductionAxisVal;
    reducedValue =
        OperationType::calcOneElement(reducedValue, oldTensor[index]);
  }
  return reducedValue;
}

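// Sketch of the rewrite performed by the pattern below (the IR shown is
// illustrative):
//   %in = "tosa.const"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>}
//   %out = tosa.reduce_sum %in {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
// becomes
//   %out = "tosa.const"() {value = dense<[[5, 7, 9]]> : tensor<1x3xi32>}
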
template <typename OperationType>
struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {

  ReduceConstantOptimization(MLIRContext *context,
                             bool aggressiveReduceConstant)
      : OpRewritePattern<OperationType>(context),
        aggressiveReduceConstant(aggressiveReduceConstant) {}

  using OpRewritePattern<OperationType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OperationType op,
                                PatternRewriter &rewriter) const override {
    Value inputOp = op.getInput();
    auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();

    if (!constOp)
      return rewriter.notifyMatchFailure(
          op, "reduce input must be a const operation");

    if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
      return rewriter.notifyMatchFailure(
          op, "input operation has more than one user");

    auto resultType = cast<ShapedType>(op.getOutput().getType());

    if (!resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "result type shape is not static");

    auto reductionAxis = op.getAxis();
    const auto denseElementsAttr = constOp.getValue();
    const auto shapedOldElementsValues =
        cast<ShapedType>(denseElementsAttr.getType());

    if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
      return rewriter.notifyMatchFailure(
          op, "reduce folding is currently only supported for integer types");

    auto oldShape = shapedOldElementsValues.getShape();
    auto newShape = resultType.getShape();

    auto newNumOfElements = std::accumulate(newShape.begin(), newShape.end(),
                                            int64_t(1),
                                            std::multiplies<int64_t>());
    llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);

    for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
         ++reductionIndex) {

      // Reduce all elements along the reduction axis into this output element.
      newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
          denseElementsAttr, oldShape, reductionAxis, reductionIndex);
    }

    auto rankedTensorType = cast<RankedTensorType>(resultType);
    auto denseAttr =
        mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
    return success();
  }
  const bool aggressiveReduceConstant;
};

} // namespace

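// A typical way to wire these patterns into a pass (a minimal sketch; the
// surrounding pass and the greedy-driver call are illustrative, not part of
// this file):
//   RewritePatternSet patterns(ctx);
//   mlir::tosa::populateTosaConstantReduction(
//       ctx, patterns, /*aggressiveReduceConstant=*/false);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
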
void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               bool aggressiveReduceConstant) {
  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceProdOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(
      ctx, aggressiveReduceConstant);
}

void mlir::tosa::populateTosaFoldConstantTransposePatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantTranspose>(ctx);
}

void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantReciprocal>(ctx);
}