1//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// TOSA canonicalization patterns and folders.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Quant/IR/Quant.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/Dialect/Tosa/IR/TosaOps.h"
17#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
18#include "mlir/IR/BuiltinTypeInterfaces.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Transforms/FoldUtils.h"
23#include "mlir/Transforms/InliningUtils.h"
24#include "llvm/ADT/APFloat.h"
25#include "llvm/ADT/APInt.h"
26
27#include <functional>
28
29using namespace mlir;
30using namespace mlir::tosa;
31
32//===----------------------------------------------------------------------===//
33// Operator Canonicalizers.
34//===----------------------------------------------------------------------===//
35
36//===----------------------------------------------------------------------===//
37// Tensor Data Engine Operators.
38//===----------------------------------------------------------------------===//
39
40// Check that the zero point of the tensor and padding operations are aligned.
41bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
42 // Check that padConst is a constant value and a scalar tensor
43 DenseElementsAttr padConstAttr;
44 if (!matchPattern(value: padConst, pattern: m_Constant(bind_value: &padConstAttr)) ||
45 (padConstAttr.size() != 1)) {
46 return false;
47 }
48
49 // Check that floating point pad is zero
50 if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(Val&: padConstAttr)) {
51 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
52 return padConstVal == 0.0f;
53 }
54
55 // Check that the zp and padConst align for the integer (quantized) case
56 if (auto padConstIntAttr =
57 mlir::dyn_cast<DenseIntElementsAttr>(Val&: padConstAttr)) {
58 DenseIntElementsAttr zpAttr;
59 // Check that zp is a constant value and a scalar tensor
60 if (!matchPattern(value: zp, pattern: m_Constant(bind_value: &zpAttr)) || (padConstAttr.size() != 1)) {
61 return false;
62 }
63
64 // Check equality
65 int64_t zpVal = (*zpAttr.begin()).getSExtValue();
66 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
67 return zpVal == padConstVal;
68 }
69
70 // Bail-out on unsupported type
71 return false;
72}
73
74namespace {
75template <typename OpTy>
76struct PoolPadFoldAdaptor;
77
78template <>
79struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
80 using OpTy = tosa::AvgPool2dOp;
81 static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
82 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
83 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
84 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
85 return false;
86 return true;
87 }
88 static bool checkPadConstCompliance(OpTy op, Value padConst) {
89 return checkMatchingPadConstAndZp(padConst, zp: op.getInputZp());
90 }
91 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
92 Value padInput, ArrayRef<int64_t> newPad) {
93 rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
94 op, args: op.getType(), args&: padInput, args: op.getInputZp(), args: op.getOutputZp(),
95 args: op.getKernel(), args: op.getStride(), args: rewriter.getDenseI64ArrayAttr(values: newPad),
96 args: op.getAccType());
97 }
98};
99
100template <>
101struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
102 using OpTy = tosa::MaxPool2dOp;
103 static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
104 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
105 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
106 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
107 return false;
108 return true;
109 }
110 static bool checkPadConstCompliance(OpTy, Value padConst) {
111 // Check that padConst is a constant value and a scalar tensor
112 DenseElementsAttr padConstAttr;
113 if (!matchPattern(value: padConst, pattern: m_Constant(bind_value: &padConstAttr)) ||
114 padConstAttr.size() != 1) {
115 return false;
116 }
117
118 // Pad needs to be in the minimum value to be able to merge
119 if (auto padConstFpAttr =
120 mlir::dyn_cast<DenseFPElementsAttr>(Val&: padConstAttr)) {
121 const APFloat padConstVal = *padConstFpAttr.begin();
122 const APFloat lowestVal =
123 APFloat::getLargest(Sem: padConstVal.getSemantics(), Negative: true);
124 return padConstVal == lowestVal;
125 } else if (auto padConstIntAttr =
126 mlir::dyn_cast<DenseIntElementsAttr>(Val&: padConstAttr)) {
127 const APInt padConstVal = *padConstIntAttr.begin();
128 const unsigned int bitWidth = padConstVal.getBitWidth();
129 const APInt lowestVal =
130 padConstIntAttr.getElementType().isUnsignedInteger()
131 ? APInt::getZero(numBits: bitWidth)
132 : APInt::getSignedMinValue(numBits: bitWidth);
133 return padConstVal == lowestVal;
134 }
135
136 // Bail-out on unsupported type
137 return false;
138 }
139 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
140 Value padInput, ArrayRef<int64_t> newPad) {
141 rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
142 op, args: op.getType(), args&: padInput, args: op.getKernel(), args: op.getStride(),
143 args: rewriter.getDenseI64ArrayAttr(values: newPad), args: op.getNanMode());
144 }
145};
146
147template <typename OpTy>
148struct ConvPadFoldAdaptor {
149 static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
150 return true;
151 }
152 static bool checkPadConstCompliance(OpTy op, Value padConst) {
153 return checkMatchingPadConstAndZp(padConst, op.getInputZp());
154 }
155 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
156 Value padInput, ArrayRef<int64_t> newPad) {
157 rewriter.replaceOpWithNewOp<OpTy>(
158 op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
159 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
160 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
161 }
162};
163
164// Pattern attempts to fold a `tosa.pad` operator to a following tensor
165// operation like `tosa.conv2d` by merging the padding associated with the
166// pad operator directly to the implicit padding of the tensor operation.
167// This helps eliminate the explicit padding operator if unused.
168template <typename OpTy, typename AdaptorTy>
169struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
170 using OpRewritePattern<OpTy>::OpRewritePattern;
171
172 LogicalResult matchAndRewrite(OpTy tensorOp,
173 PatternRewriter &rewriter) const override {
174 // Check producer is a tosa::PadOp
175 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
176 if (!padOp)
177 return rewriter.notifyMatchFailure(tensorOp,
178 "Producer must be a tosa::PadOp.");
179
180 // Validate that tensor operation has sane padding
181 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
182 if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
183 return rewriter.notifyMatchFailure(
184 tensorOp, "Tensor operation padding shall have 4 elements.");
185
186 // Validate tosa::PadOp padding
187 DenseIntElementsAttr padOpPadding;
188 if (!matchPattern(padOp.getPadding(), m_Constant(bind_value: &padOpPadding))) {
189 return rewriter.notifyMatchFailure(
190 tensorOp,
191 "The `padding` input specified on the tosa::PadOp must be constant.");
192 }
193 // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
194 // C_after
195 if (padOpPadding.size() != 8)
196 return rewriter.notifyMatchFailure(tensorOp,
197 "Pad padding should have 8 elements.");
198 int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
199 int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
200 int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
201 int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
202 int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
203 int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
204 int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
205 int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
206
207 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
208 return rewriter.notifyMatchFailure(
209 tensorOp, "Folding padding in N or C dimensions is not supported.");
210
211 // Fold padding from Pad into the tensor operation
212 // 4 elements - pad_top, pad_bottom, pad_left, pad_right
213 SmallVector<int64_t> foldedPad(tensorOpPad.size());
214 foldedPad[0] = padHBefore + tensorOpPad[0];
215 foldedPad[1] = padHAfter + tensorOpPad[1];
216 foldedPad[2] = padWBefore + tensorOpPad[2];
217 foldedPad[3] = padWAfter + tensorOpPad[3];
218
219 // Check kernel related restrictions
220 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
221 return rewriter.notifyMatchFailure(
222 tensorOp, "Padding size not aligned with kernel restrictions.");
223 }
224
225 // Check padding constant restrictions
226 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
227 return rewriter.notifyMatchFailure(
228 tensorOp,
229 "Padding constant is not aligned with operator zero-point.");
230 }
231
232 // Check that padding doesn't grow more than 8K level (8192) for now
233 if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
234 return rewriter.notifyMatchFailure(
235 tensorOp, "Padding size more than the 8K level limit.");
236 }
237
238 // Create operator
239 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
240 foldedPad);
241
242 return success();
243 }
244};
245} // namespace
246
247void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
248 MLIRContext *context) {
249 results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
250 PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
251 arg&: context);
252}
253
254void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
255 MLIRContext *context) {
256 results.add<
257 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
258 arg&: context);
259}
260
261void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
262 MLIRContext *context) {
263 results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
264 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
265 arg&: context);
266}
267
268struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
269 using OpRewritePattern::OpRewritePattern;
270
271 LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
272 PatternRewriter &rewriter) const override {
273 Value input = op.getInput();
274 Value output = op.getOutput();
275 ShapedType inputType = llvm::cast<ShapedType>(Val: input.getType());
276 ShapedType outputType = llvm::cast<ShapedType>(Val: output.getType());
277
278 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
279 return failure();
280 }
281
282 // If the output and input shapes are 1x1, then this is a no op.
283 ArrayRef<int64_t> outputShape = outputType.getShape();
284 if (outputShape[1] != 1 || outputShape[2] != 1) {
285 return failure();
286 }
287
288 ArrayRef<int64_t> inputShape = inputType.getShape();
289 if (inputShape[1] != 1 || inputShape[2] != 1) {
290 return failure();
291 }
292
293 rewriter.replaceOp(op, newValues: input);
294 return success();
295 }
296};
297
298void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
299 MLIRContext *context) {
300 results.add<MaxPool2dIsNoOp,
301 FoldPadToTensorOp<tosa::MaxPool2dOp,
302 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
303 arg&: context);
304}
305
306//===----------------------------------------------------------------------===//
307// Data Layout / Memory Reinterpretation.
308//===----------------------------------------------------------------------===//
309
310struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
311 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
312
313 LogicalResult matchAndRewrite(tosa::ConcatOp op,
314 PatternRewriter &rewriter) const override {
315 if (op.getInput1().size() != 1)
316 return failure();
317 if (op.getInput1().front().getType() != op.getType()) {
318 rewriter
319 .replaceOpWithNewOp<tensor::CastOp>(op, args: op.getType(),
320 args: op.getInput1().front())
321 .getResult();
322 return success();
323 }
324
325 rewriter.replaceOp(op, newValues: op.getInput1().front());
326 return success();
327 }
328};
329
330void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
331 MLIRContext *context) {
332 results.add<ConcatOptimization>(arg&: context);
333}
334
335LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
336 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
337 if (!notOp)
338 return failure();
339 rewriter.modifyOpInPlace(root: op, callable: [&]() {
340 op.getOperation()->setOperands(
341 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
342 });
343 return success();
344}
345
346struct ConsolidateTransposeOptimization
347 : public OpRewritePattern<tosa::TransposeOp> {
348 using OpRewritePattern::OpRewritePattern;
349
350 LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
351 PatternRewriter &rewriter) const override {
352 // Input is also TransposeOp - transpose(transpose(A)).
353 auto innerTranspose =
354 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
355 if (!innerTranspose)
356 return rewriter.notifyMatchFailure(arg&: transposeOp,
357 msg: "input must be transpose operation");
358
359 const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
360 const llvm::ArrayRef<int32_t> innerTransposePerms =
361 innerTranspose.getPerms();
362
363 if (transposePerms.size() != innerTransposePerms.size())
364 return rewriter.notifyMatchFailure(
365 arg&: transposeOp,
366 msg: "transpose and inner transpose perms sizes must be equal");
367 if (transposePerms.empty())
368 return rewriter.notifyMatchFailure(
369 arg&: transposeOp, msg: "transpose perms sizes must be positive");
370
371 // Consolidate transposes into one transpose.
372 SmallVector<int32_t> perms(transposePerms.size());
373 for (int i = 0, s = transposePerms.size(); i < s; ++i)
374 perms[i] = innerTransposePerms[transposePerms[i]];
375
376 rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
377 op: transposeOp, args: transposeOp.getResult().getType(),
378 args: innerTranspose.getInput1(), args: rewriter.getDenseI32ArrayAttr(values: perms));
379
380 return success();
381 }
382};
383
384// Determines the case when tosa.transpose is a tosa.reshape operation.
385struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
386 using OpRewritePattern::OpRewritePattern;
387
388 LogicalResult matchAndRewrite(tosa::TransposeOp op,
389 PatternRewriter &rewriter) const override {
390 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
391 return rewriter.notifyMatchFailure(
392 arg&: op, msg: "Src is from transpose, can compose transposes");
393
394 Value result = op.getResult();
395 for (Operation *subop : result.getUsers()) {
396 if (isa_and_nonnull<tosa::TransposeOp>(Val: subop))
397 return rewriter.notifyMatchFailure(
398 arg&: op, msg: "Dest is used by transpose, can compose transposes");
399 }
400
401 auto input = op.getInput1();
402 auto inputTy = llvm::cast<ShapedType>(Val: input.getType());
403 if (!inputTy.hasRank())
404 return rewriter.notifyMatchFailure(arg&: op, msg: "Unranked input.");
405
406 int64_t numDynDims = 0;
407 for (int i = 0; i < inputTy.getRank(); ++i)
408 if (inputTy.isDynamicDim(idx: i))
409 numDynDims++;
410
411 if (numDynDims > 1)
412 return rewriter.notifyMatchFailure(arg&: op, msg: "Has more than one dynamic dim.");
413
414 const llvm::ArrayRef<int32_t> permValues = op.getPerms();
415
416 SmallVector<int64_t> nonZeroPerms;
417 nonZeroPerms.reserve(N: permValues.size());
418 for (auto idx : permValues) {
419 auto sz = inputTy.getDimSize(idx);
420 if (sz != 1)
421 nonZeroPerms.push_back(Elt: idx);
422 }
423
424 for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
425 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
426 return rewriter.notifyMatchFailure(arg&: op,
427 msg: "Transpose changes memory layout.");
428
429 SmallVector<int64_t> newShape;
430 newShape.reserve(N: inputTy.getRank());
431 for (int i = 0, s = inputTy.getRank(); i < s; ++i)
432 newShape.push_back(Elt: inputTy.getDimSize(idx: permValues[i]));
433
434 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
435 op, args: op.getType(), args: op.getInput1(),
436 args: getTosaConstShape(rewriter, loc: op.getLoc(), shape: newShape));
437 return success();
438 }
439};
440
441void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
442 MLIRContext *context) {
443 results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(arg&: context);
444}
445
446struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
447 using OpRewritePattern::OpRewritePattern;
448
449 LogicalResult matchAndRewrite(tosa::ClampOp op,
450 PatternRewriter &rewriter) const override {
451 Value input = op.getInput();
452 auto inputType = llvm::dyn_cast<RankedTensorType>(Val: op.getInput().getType());
453 auto inputElementType = inputType.getElementType();
454
455 if (!inputType.hasStaticShape()) {
456 return failure();
457 }
458
459 if (isa<FloatType>(Val: inputElementType)) {
460 // Unlike integer types, floating point types can represent infinity.
461 auto minClamp =
462 llvm::cast<mlir::FloatAttr>(Val: op.getMinValAttr()).getValue();
463 auto maxClamp =
464 llvm::cast<mlir::FloatAttr>(Val: op.getMaxValAttr()).getValue();
465 bool isMin = minClamp.isNegInfinity();
466 bool isMax = maxClamp.isInfinity();
467
468 if (isMin && isMax) {
469 rewriter.replaceOp(op, newValues: input);
470 return success();
471 }
472 return failure();
473 }
474
475 if (inputElementType.isUnsignedInteger()) {
476 int64_t minClamp =
477 llvm::cast<mlir::IntegerAttr>(Val: op.getMinValAttr()).getUInt();
478 int64_t maxClamp =
479 llvm::cast<mlir::IntegerAttr>(Val: op.getMaxValAttr()).getUInt();
480
481 int64_t intMin =
482 APInt::getMinValue(numBits: inputElementType.getIntOrFloatBitWidth())
483 .getZExtValue();
484 int64_t intMax =
485 APInt::getMaxValue(numBits: inputElementType.getIntOrFloatBitWidth())
486 .getZExtValue();
487
488 if (minClamp <= intMin && maxClamp >= intMax) {
489 rewriter.replaceOp(op, newValues: input);
490 return success();
491 }
492 return failure();
493 }
494
495 if (llvm::isa<IntegerType>(Val: inputElementType)) {
496 int64_t minClamp =
497 llvm::cast<mlir::IntegerAttr>(Val: op.getMinValAttr()).getInt();
498 int64_t maxClamp =
499 llvm::cast<mlir::IntegerAttr>(Val: op.getMaxValAttr()).getInt();
500
501 int64_t intMin =
502 APInt::getSignedMinValue(numBits: inputElementType.getIntOrFloatBitWidth())
503 .getSExtValue();
504 int64_t intMax =
505 APInt::getSignedMaxValue(numBits: inputElementType.getIntOrFloatBitWidth())
506 .getSExtValue();
507
508 if (minClamp <= intMin && maxClamp >= intMax) {
509 rewriter.replaceOp(op, newValues: input);
510 return success();
511 }
512 return failure();
513 }
514
515 return failure();
516 }
517};
518
519// Attempts the following transformation:
520//
521// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
522// tensor X the following identity holds:
523//
524// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
525//
526// subject to the following valid NaN propagation semantics:
527// --------------------------------------------
528// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
529// |-------------|--------------|-------------|
530// | PROPAGATE | PROPAGATE | PROPAGATE |
531// | PROPAGATE | IGNORE | IGNORE |
532// | IGNORE | PROPAGATE | INVALID |
533// | IGNORE | IGNORE | IGNORE |
534// |------------------------------------------|
535
// Merges clamp(clamp(x, a, b), a', b') into clamp(x, max(a, a'), min(b, b'))
// when the two ranges intersect; see the NaN-propagation table above for the
// legal nan_mode combinations.
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  // Helper structure to describe the range of a clamp operation.
  template <typename T>
  struct ClampRange {
    ClampRange(const T &start, const T &end) : start(start), end(end) {}
    T start;
    T end;

    // Helper function to determine if two Clamp ranges intersect.
    // Note: half-open comparison — touching-but-not-overlapping ranges do
    // not count as intersecting.
    bool intersects(const ClampRange<T> &otherRange) {
      return start < otherRange.end && otherRange.start < end;
    }
  };

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();

    // Check the input to the CLAMP op is itself a CLAMP.
    auto clampOp = dyn_cast_if_present<tosa::ClampOp>(Val: input.getDefiningOp());
    if (!clampOp)
      return failure();

    // Check we have a valid NaN propagation combination.
    // (outer IGNORE + inner PROPAGATE is the one invalid pairing.)
    const auto opNanMode = op.getNanMode();
    const auto clampNanMode = clampOp.getNanMode();
    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
      return failure();

    auto maxValAttr = op.getMaxValAttr();
    auto minValAttr = op.getMinValAttr();
    auto clampOpMaxValAttr = clampOp.getMaxValAttr();
    auto clampOpMinValAttr = clampOp.getMinValAttr();

    // For quantized tensors, compare bounds in the storage type's domain.
    auto inputEType = llvm::cast<ShapedType>(Val: input.getType()).getElementType();
    if (auto quantType =
            llvm::dyn_cast<mlir::quant::UniformQuantizedType>(Val&: inputEType)) {
      inputEType = quantType.getStorageType();
    }

    Attribute newMinValAttr, newMaxValAttr;
    if (mlir::isa<FloatType>(Val: inputEType)) {
      // Float path: intersect the APFloat ranges.
      auto floatMaxValAttr = cast<mlir::FloatAttr>(Val&: maxValAttr);
      auto floatMinValAttr = cast<mlir::FloatAttr>(Val&: minValAttr);
      auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(Val&: clampOpMaxValAttr);
      auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(Val&: clampOpMinValAttr);

      // Check we have intersecting ranges.
      const auto opMinFloat = floatMinValAttr.getValue();
      const auto opMaxFloat = floatMaxValAttr.getValue();
      const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
      const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
      ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
      ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
                                               clampOpMaxFloat);
      if (!opRangeFloatRange.intersects(otherRange: clampRangeFloatRange))
        return failure();

      // Run the transformation.
      auto newMinVal = std::max(a: opMinFloat, b: clampOpMinFloat);
      auto newMaxVal = std::min(a: opMaxFloat, b: clampOpMaxFloat);
      newMinValAttr = rewriter.getFloatAttr(type: inputEType, value: newMinVal);
      newMaxValAttr = rewriter.getFloatAttr(type: inputEType, value: newMaxVal);
    } else {
      assert(mlir::isa<IntegerType>(inputEType));
      auto intMaxValAttr = cast<mlir::IntegerAttr>(Val&: maxValAttr);
      auto intMinValAttr = cast<mlir::IntegerAttr>(Val&: minValAttr);
      auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(Val&: clampOpMaxValAttr);
      auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(Val&: clampOpMinValAttr);

      if (inputEType.isUnsignedInteger()) {
        // Unsigned path: intersect as uint64_t.
        // Check we have intersecting ranges.
        const auto opMinInt = intMinValAttr.getUInt();
        const auto opMaxInt = intMaxValAttr.getUInt();
        const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
        const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
        ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
        ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
                                                     clampOpMaxInt);
        if (!opRangeIntRange.intersects(otherRange: clampRangeIntRange))
          return failure();

        // Run the transformation.
        auto newMinVal = std::max(a: opMinInt, b: clampOpMinInt);
        auto newMaxVal = std::min(a: opMaxInt, b: clampOpMaxInt);
        newMinValAttr = rewriter.getIntegerAttr(type: inputEType, value: newMinVal);
        newMaxValAttr = rewriter.getIntegerAttr(type: inputEType, value: newMaxVal);
      } else {
        // Signed path: intersect as int64_t.
        // Check we have intersecting ranges.
        const auto opMinInt = intMinValAttr.getInt();
        const auto opMaxInt = intMaxValAttr.getInt();
        const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
        const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
        ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
        ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
                                                    clampOpMaxInt);
        if (!opRangeIntRange.intersects(otherRange: clampRangeIntRange))
          return failure();

        // Run the transformation.
        auto newMinVal = std::max(a: opMinInt, b: clampOpMinInt);
        auto newMaxVal = std::min(a: opMaxInt, b: clampOpMaxInt);
        newMinValAttr = rewriter.getIntegerAttr(type: inputEType, value: newMinVal);
        newMaxValAttr = rewriter.getIntegerAttr(type: inputEType, value: newMaxVal);
      }
    }

    // Replace with a single clamp; per the table above, any disagreement in
    // nan_mode resolves to IGNORE.
    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, args: op.getType(), args: clampOp.getInput(), args&: newMinValAttr, args&: newMaxValAttr,
        args: rewriter.getStringAttr(bytes: (opNanMode != clampNanMode) ? "IGNORE"
                                                              : opNanMode));
    return success();
  }
};
652
653void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
654 MLIRContext *context) {
655 results.add<ClampIsNoOp>(arg&: context);
656 results.add<ClampClampOptimization>(arg&: context);
657}
658
659struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
660 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
661
662 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
663 PatternRewriter &rewriter) const override {
664 Value sliceInput = sliceOp.getInput1();
665 auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
666 if (!concatOp)
667 return rewriter.notifyMatchFailure(
668 arg&: sliceOp, msg: "slice input must be concat operation");
669
670 OperandRange inputs = concatOp.getInput1();
671 auto concatType = dyn_cast<RankedTensorType>(Val: concatOp.getType());
672 if (!concatType || !concatType.hasStaticShape())
673 return rewriter.notifyMatchFailure(
674 arg&: sliceOp, msg: "slice input must be a static ranked tensor");
675 int32_t axis = concatOp.getAxis();
676
677 DenseElementsAttr startElems;
678 DenseElementsAttr sizeElems;
679
680 if (!matchPattern(value: sliceOp.getStart(), pattern: m_Constant(bind_value: &startElems)))
681 return rewriter.notifyMatchFailure(
682 arg&: sliceOp, msg: "start of slice must be a static ranked shape");
683
684 if (!matchPattern(value: sliceOp.getSize(), pattern: m_Constant(bind_value: &sizeElems)))
685 return rewriter.notifyMatchFailure(
686 arg&: sliceOp, msg: "size of slice must be a static ranked shape");
687
688 llvm::SmallVector<int64_t> sliceStarts =
689 llvm::to_vector(Range: startElems.getValues<int64_t>());
690 llvm::SmallVector<int64_t> sliceSizes =
691 llvm::to_vector(Range: sizeElems.getValues<int64_t>());
692
693 // Validate slice on the concatenated axis. Slicing along this
694 // axis should span only one of the inputs to the concatenate
695 // operation.
696 std::optional<Value> replaceWithSlice;
697 for (auto input : inputs) {
698 auto inputType = dyn_cast<RankedTensorType>(Val: input.getType());
699 if (!inputType || !inputType.hasStaticShape())
700 return rewriter.notifyMatchFailure(
701 arg&: sliceOp, msg: "concat input must be a static ranked tensor");
702
703 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
704 inputType.getDimSize(idx: axis)) {
705 auto start_op =
706 getTosaConstShape(rewriter, loc: sliceOp.getLoc(), shape: sliceStarts);
707 auto size_op =
708 getTosaConstShape(rewriter, loc: sliceOp.getLoc(), shape: sliceSizes);
709 replaceWithSlice =
710 rewriter
711 .create<tosa::SliceOp>(location: sliceOp.getLoc(), args: sliceOp.getType(),
712 args&: input, args&: start_op, args&: size_op)
713 .getResult();
714 break;
715 }
716 sliceStarts[axis] -= inputType.getDimSize(idx: axis);
717 }
718
719 if (!replaceWithSlice)
720 return rewriter.notifyMatchFailure(
721 arg&: sliceOp, msg: "corresponding concat input not found for slice");
722
723 rewriter.replaceOp(op: sliceOp, newValues: replaceWithSlice.value());
724 return success();
725 }
726};
727
// Pushes a tosa.slice through its producing tosa.pad: the slice start is
// shifted out of the padded region and the pad amounts are shrunk to only
// what the slice still observes, so the pad materializes less data.
struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();

    // Check if producer is a PadOp
    auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
    if (!padOp)
      return rewriter.notifyMatchFailure(arg&: sliceOp,
                                         msg: "slice input must be a pad operation");

    // Check PadOp has a single consumer
    // (otherwise shrinking the pad would change what other users see).
    if (!padOp->hasOneUse())
      return rewriter.notifyMatchFailure(arg&: sliceOp,
                                         msg: "pad shall have a single consumer");

    // Check input is statically ranked
    auto inputTy = dyn_cast<RankedTensorType>(Val: padOp.getInput1().getType());
    auto padTy = dyn_cast<RankedTensorType>(Val: padOp.getType());
    if (!inputTy || !padTy || !inputTy.hasRank())
      return rewriter.notifyMatchFailure(arg&: sliceOp,
                                         msg: "slice input must be a ranked tensor");

    // Validate and extract tosa::PadOp padding
    DenseIntElementsAttr paddingElems;
    if (!matchPattern(value: padOp.getPadding(), pattern: m_Constant(bind_value: &paddingElems))) {
      return rewriter.notifyMatchFailure(
          arg&: sliceOp,
          msg: "`padding` input specified on the tosa::PadOp must be constant.");
    }
    llvm::SmallVector<int64_t> padPaddings =
        llvm::to_vector(Range: paddingElems.getValues<int64_t>());

    // Extract slice parameters
    DenseElementsAttr startElems;
    if (!matchPattern(value: sliceOp.getStart(), pattern: m_Constant(bind_value: &startElems)))
      return rewriter.notifyMatchFailure(
          arg&: sliceOp, msg: "start of slice must be a static ranked shape");
    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(Range: startElems.getValues<int64_t>());

    DenseElementsAttr sizeElems;
    if (!matchPattern(value: sliceOp.getSize(), pattern: m_Constant(bind_value: &sizeElems)))
      return rewriter.notifyMatchFailure(
          arg&: sliceOp, msg: "size of slice must be a static ranked shape");
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(Range: sizeElems.getValues<int64_t>());

    // Check if dynamic dimensions are sliced
    // (a dimension is "sliced" unless start == 0 and size == -1, i.e. keep-all).
    const int64_t rank = inputTy.getRank();
    if (llvm::any_of(Range: llvm::seq<int64_t>(Begin: 0, End: rank), P: [&](int64_t i) {
          const bool isDimDynamic = inputTy.isDynamicDim(idx: i);
          const bool isDimSliced =
              (sliceStarts[i] != 0) || (sliceSizes[i] != -1);

          return isDimDynamic && isDimSliced;
        })) {
      return rewriter.notifyMatchFailure(
          arg&: sliceOp, msg: "axis that are sliced shall be statically known.");
    }

    // Update the parameters
    llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
    llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
    llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
    bool updated = false;

    for (int64_t i = 0; i < rank; ++i) {
      const int64_t padLo = padPaddings[i * 2];
      const int64_t padHi = padPaddings[i * 2 + 1];
      const int64_t sliceStart = sliceStarts[i];
      const int64_t sliceSize = sliceSizes[i];
      const int64_t sliceEnd = sliceStart + sliceSize;

      // If dimension is dynamic pass-through
      if (inputTy.isDynamicDim(idx: i)) {
        newPadPaddings[i * 2] = padLo;
        newPadPaddings[i * 2 + 1] = padHi;
        newSliceStarts[i] = sliceStart;
        continue;
      }

      // Handle static dimensions
      const int64_t dimSize = inputTy.getShape()[i];
      const int64_t dimTotal = padLo + dimSize + padHi;

      // Check slice within bounds
      if (sliceStart < 0 || sliceEnd > dimTotal)
        return rewriter.notifyMatchFailure(arg&: sliceOp, msg: "slice is out-of-bounds");

      // Compute updated slice start parameter
      // (shift the start left by however much low padding it skipped over).
      const int64_t newSliceStart = std::max<int64_t>(a: sliceStart - padLo, b: 0);
      newSliceStarts[i] = newSliceStart;
      updated |= newSliceStart != sliceStart;

      // Compute updated pad parameters
      // (keep only the padding that the slice window still covers).
      const int64_t newPadLo = std::max<int64_t>(a: padLo - sliceStart, b: 0);
      const int64_t newPadHi =
          std::max<int64_t>(a: sliceEnd - (padLo + dimSize), b: 0);
      newPadPaddings[i * 2] = newPadLo;
      newPadPaddings[i * 2 + 1] = newPadHi;
      updated |= (newPadLo != padLo) || (newPadHi != padHi);

      // Calculate new pad output shape
      newPadShape[i] =
          newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
    }

    // Check that we actually need to proceed with the rewrite
    if (!updated)
      return rewriter.notifyMatchFailure(
          arg&: sliceOp, msg: "terminate condition; nothing to rewrite");

    // Create a PadOp with updated padding
    auto newPaddingsOp =
        getTosaConstShape(rewriter, loc: sliceOp.getLoc(), shape: newPadPaddings);
    auto newPadTy =
        RankedTensorType::get(shape: newPadShape, elementType: inputTy.getElementType());
    auto newPadOp = rewriter.create<tosa::PadOp>(
        location: padOp.getLoc(), args&: newPadTy, args: padOp.getInput1(), args&: newPaddingsOp,
        args: padOp.getPadConst());

    // Update SliceOp and point to new PadOp
    // (size operand is unchanged; only the start moves).
    auto newStartOp =
        getTosaConstShape(rewriter, loc: sliceOp.getLoc(), shape: newSliceStarts);
    rewriter.replaceOpWithNewOp<tosa::SliceOp>(op: sliceOp, args: sliceOp.getType(),
                                               args: newPadOp.getResult(), args&: newStartOp,
                                               args: sliceOp.getSize());

    return success();
  }
};
862
863// Update size operand of tosa.slice if size has dynamic dims but corresponding
864// output dim is static
865struct SliceDynamicSizeCanonicalization
866 : public OpRewritePattern<tosa::SliceOp> {
867 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
868
869 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
870 PatternRewriter &rewriter) const override {
871 ShapedType resultType = cast<ShapedType>(Val: sliceOp.getType());
872
873 ElementsAttr sizeElems;
874 if (!matchPattern(value: sliceOp.getSize(), pattern: m_Constant(bind_value: &sizeElems))) {
875 return rewriter.notifyMatchFailure(
876 arg&: sliceOp, msg: "size of slice must be a static ranked shape");
877 }
878
879 llvm::SmallVector<int64_t> sliceSizes =
880 llvm::to_vector(Range: sizeElems.getValues<int64_t>());
881
882 bool replaceSliceSize{false};
883 // if size op has -1 indicating dynamic shape but corresponding dim on the
884 // output is statically known, update size to match with known output dim
885 // shape
886 for (const auto &[index, size] : llvm::enumerate(First&: sliceSizes)) {
887 if (size == -1 && !resultType.isDynamicDim(idx: index)) {
888 sliceSizes[index] = resultType.getDimSize(idx: index);
889 replaceSliceSize = true;
890 }
891 }
892
893 if (!replaceSliceSize) {
894 return rewriter.notifyMatchFailure(
895 arg&: sliceOp, msg: "no dimension of size of slice is dynamic that resolves "
896 "to static output shape");
897 }
898
899 auto size_op = getTosaConstShape(rewriter, loc: sliceOp.getLoc(), shape: sliceSizes);
900 auto newSliceOp = rewriter.create<tosa::SliceOp>(
901 location: sliceOp.getLoc(), args: sliceOp.getType(), args: sliceOp.getInput1(),
902 args: sliceOp.getStart(), args&: size_op);
903
904 rewriter.replaceOp(op: sliceOp, newValues: newSliceOp.getResult());
905 return success();
906 }
907};
908
void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  // Register all tosa.slice canonicalizations: merging with a producing
  // concat or pad, and refining dynamic size operands from static results.
  results.add<ConcatSliceOptimization, PadSliceOptimization,
              SliceDynamicSizeCanonicalization>(arg&: context);
}
914
915//===----------------------------------------------------------------------===//
916// Operator Folders.
917//===----------------------------------------------------------------------===//
918
919template <typename IntFolder, typename FloatFolder>
920DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
921 RankedTensorType returnTy) {
922 if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
923 auto lETy = llvm::cast<ShapedType>(Val: lhs.getType()).getElementType();
924 auto rETy = llvm::cast<ShapedType>(Val: rhs.getType()).getElementType();
925 if (lETy != rETy)
926 return {};
927
928 if (llvm::isa<IntegerType>(Val: lETy)) {
929 APInt l = lhs.getSplatValue<APInt>();
930 APInt r = rhs.getSplatValue<APInt>();
931 auto result = IntFolder()(l, r);
932 return DenseElementsAttr::get(returnTy, result);
933 }
934
935 if (llvm::isa<FloatType>(Val: lETy)) {
936 APFloat l = lhs.getSplatValue<APFloat>();
937 APFloat r = rhs.getSplatValue<APFloat>();
938 auto result = FloatFolder()(l, r);
939 return DenseElementsAttr::get(returnTy, result);
940 }
941 }
942
943 return {};
944}
945
946static bool isSplatZero(Type elemType, DenseElementsAttr val) {
947 if (llvm::isa<FloatType>(Val: elemType))
948 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
949 if (llvm::isa<IntegerType>(Val: elemType))
950 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
951 return false;
952}
953
954static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
955 if (llvm::isa<FloatType>(Val: elemType))
956 return val && val.isSplat() &&
957 val.getSplatValue<APFloat>().isExactlyValue(V: 1.0);
958 if (llvm::isa<IntegerType>(Val: elemType)) {
959 const int64_t shifted = 1LL << shift;
960 return val && val.isSplat() &&
961 val.getSplatValue<APInt>().getSExtValue() == shifted;
962 }
963 return false;
964}
965
966OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
967 auto lhsTy = llvm::dyn_cast<RankedTensorType>(Val: getInput1().getType());
968 auto rhsTy = llvm::dyn_cast<RankedTensorType>(Val: getInput2().getType());
969 auto resultTy = llvm::dyn_cast<RankedTensorType>(Val: getType());
970 if (!lhsTy || !rhsTy || !resultTy)
971 return {};
972
973 // Cannot create an ElementsAttr from non-int/float/index types
974 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
975 !rhsTy.getElementType().isIntOrIndexOrFloat())
976 return {};
977
978 auto resultETy = resultTy.getElementType();
979 auto lhsAttr =
980 llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput1());
981 auto rhsAttr =
982 llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput2());
983
984 if (lhsTy == resultTy && isSplatZero(elemType: resultETy, val: rhsAttr))
985 return getInput1();
986 if (rhsTy == resultTy && isSplatZero(elemType: resultETy, val: lhsAttr))
987 return getInput2();
988
989 if (!lhsAttr || !rhsAttr)
990 return {};
991
992 return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhs: lhsAttr, rhs: rhsAttr,
993 returnTy: resultTy);
994}
995
996OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
997 auto inputTy = llvm::dyn_cast<RankedTensorType>(Val: getInput().getType());
998 auto outputTy = llvm::dyn_cast<RankedTensorType>(Val: getType());
999 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1000 !outputTy.hasStaticShape())
1001 return {};
1002
1003 if (inputTy.getDimSize(idx: getAxis()) == 1)
1004 return DenseElementsAttr::get(type: outputTy, value: 0);
1005
1006 return {};
1007}
1008
OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  // Operand and result types must all be ranked and the operand types equal,
  // so no implicit broadcast behavior is lost by folding.
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(Val: getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(Val: getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(Val: getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy != rhsTy)
    return {};

  // IntDivOp inputs must be integer type, no need to check for quantized type
  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput2());
  // 0 / x == 0 when the numerator is a constant all-zero splat.
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(Val: resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  // x / 1 == x when the denominator is a constant all-one splat.
  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(Val: resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  // Fold splat / splat using signed division; a zero divisor is deliberately
  // left unfolded.
  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
      llvm::isa<IntegerType>(Val: resultETy)) {
    APInt l = lhsAttr.getSplatValue<APInt>();
    APInt r = rhsAttr.getSplatValue<APInt>();
    if (!r.isZero()) {
      APInt result = l.sdiv(RHS: r);
      return DenseElementsAttr::get(type: resultTy, values: result);
    }
  }

  return {};
}
1048
1049namespace {
1050// calculate lhs * rhs >> shift according to TOSA Spec
1051// return nullopt if result is not in range of int32_t when shift > 0
1052std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
1053 unsigned bitwidth) {
1054 APInt result = lhs.sext(width: 64) * rhs.sext(width: 64);
1055
1056 if (shift > 0) {
1057 auto round = APInt(64, 1) << (shift - 1);
1058 result += round;
1059 result.ashrInPlace(ShiftAmt: shift);
1060 // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
1061 if (!(result.getSExtValue() >= INT32_MIN &&
1062 result.getSExtValue() <= INT32_MAX)) {
1063 // REQUIRE failed
1064 return std::nullopt;
1065 }
1066 }
1067
1068 return result.trunc(width: bitwidth);
1069}
1070
// Fold an elementwise tosa.mul over two constant splats, applying the TOSA
// result right-shift for integer element types. Returns a null attribute
// when folding is not possible.
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(Val: ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      // No shift: plain integer multiply.
      if (shift == 0) {
        return DenseElementsAttr::get(type: ty, values: l * r);
      }

      // Shifted multiply may violate the spec's i32 range REQUIRE; in that
      // case mulInt returns nullopt and we refuse to fold.
      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      const std::optional<APInt> result = mulInt(lhs: l, rhs: r, shift, bitwidth);
      if (!result)
        return {};
      return DenseElementsAttr::get(type: ty, values: result.value());
    }

    // Float multiply; the shift operand does not apply to floats.
    if (llvm::isa<FloatType>(Val: ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(type: ty, values: result);
    }
  }

  return {};
}
1099} // namespace
1100
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(Val: lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(Val: rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(Val: getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput2());

  // Result right shift on i32_t data type only. For simplification, synthesize
  // a zero shift for other data type.
  int32_t shift = 0;
  if (resultETy.isInteger(width: 32)) {
    ElementsAttr shift_elem;
    if (getShift().getImpl()) {
      if (!matchPattern(value: getShift(), pattern: m_Constant(bind_value: &shift_elem)))
        // cannot be folded when the shift value is unknown.
        return {};
      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
    }
  }

  // Identity folds: 0 * x == 0 and 1 * x == x, where integer "1" is 1 << shift.
  // The type guards ensure no implicit broadcast is discarded.
  if (rhsTy == resultTy) {
    if (isSplatZero(elemType: resultETy, val: lhsAttr))
      return lhsAttr.resizeSplat(newType: resultTy);
    if (isSplatOne(elemType: resultETy, val: lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    if (isSplatZero(elemType: resultETy, val: rhsAttr))
      return rhsAttr.resizeSplat(newType: resultTy);
    if (isSplatOne(elemType: resultETy, val: rhsAttr, shift))
      return lhs;
  }

  // Full splat * splat constant fold (applying the shift).
  return mulBinaryFolder(lhs: lhsAttr, rhs: rhsAttr, ty: resultTy, shift);
}
1144
1145OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1146 auto lhsTy = llvm::dyn_cast<RankedTensorType>(Val: getInput1().getType());
1147 auto rhsTy = llvm::dyn_cast<RankedTensorType>(Val: getInput2().getType());
1148 auto resultTy = llvm::dyn_cast<RankedTensorType>(Val: getType());
1149 if (!lhsTy || !rhsTy || !resultTy)
1150 return {};
1151
1152 // Cannot create an ElementsAttr from non-int/float/index types
1153 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1154 !rhsTy.getElementType().isIntOrIndexOrFloat())
1155 return {};
1156
1157 auto resultETy = resultTy.getElementType();
1158 auto lhsAttr =
1159 llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput1());
1160 auto rhsAttr =
1161 llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput2());
1162
1163 if (lhsTy == resultTy && isSplatZero(elemType: resultETy, val: rhsAttr))
1164 return getInput1();
1165
1166 if (!lhsAttr || !rhsAttr)
1167 return {};
1168
1169 return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhs: lhsAttr, rhs: rhsAttr,
1170 returnTy: resultTy);
1171}
1172
1173namespace {
1174template <typename Cmp>
1175struct ComparisonFold {
1176 ComparisonFold() = default;
1177 APInt operator()(const APInt &l, const APInt &r) {
1178 return APInt(1, Cmp()(l, r));
1179 }
1180
1181 APInt operator()(const APFloat &l, const APFloat &r) {
1182 return APInt(1, Cmp()(l, r));
1183 }
1184};
1185
1186struct APIntFoldGreater {
1187 APIntFoldGreater() = default;
1188 APInt operator()(const APInt &l, const APInt &r) {
1189 return APInt(1, l.sgt(RHS: r));
1190 }
1191};
1192
1193struct APIntFoldGreaterEqual {
1194 APIntFoldGreaterEqual() = default;
1195 APInt operator()(const APInt &l, const APInt &r) {
1196 return APInt(1, l.sge(RHS: r));
1197 }
1198};
1199} // namespace
1200
1201OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1202 auto resultTy = llvm::dyn_cast<RankedTensorType>(Val: getType());
1203 auto lhsAttr =
1204 llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput1());
1205 auto rhsAttr =
1206 llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput2());
1207
1208 if (!lhsAttr || !rhsAttr)
1209 return {};
1210
1211 return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
1212 lhs: lhsAttr, rhs: rhsAttr, returnTy: resultTy);
1213}
1214
1215OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1216 auto resultTy = llvm::dyn_cast<RankedTensorType>(Val: getType());
1217 auto lhsAttr =
1218 llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput1());
1219 auto rhsAttr =
1220 llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput2());
1221
1222 if (!lhsAttr || !rhsAttr)
1223 return {};
1224
1225 return binaryFolder<APIntFoldGreaterEqual,
1226 ComparisonFold<std::greater_equal<APFloat>>>(
1227 lhs: lhsAttr, rhs: rhsAttr, returnTy: resultTy);
1228}
1229
1230OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1231 auto resultTy = llvm::dyn_cast<RankedTensorType>(Val: getType());
1232 auto lhsAttr =
1233 llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput1());
1234 auto rhsAttr =
1235 llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput2());
1236 Value lhs = getInput1();
1237 Value rhs = getInput2();
1238 auto lhsTy = llvm::cast<ShapedType>(Val: lhs.getType());
1239
1240 // If we are comparing an integer value to itself it is always true. We can
1241 // not do this with float due to float values.
1242 if (llvm::isa<IntegerType>(Val: lhsTy.getElementType()) && resultTy &&
1243 resultTy.hasStaticShape() && lhs == rhs) {
1244 return DenseElementsAttr::get(type: resultTy, value: true);
1245 }
1246
1247 if (!lhsAttr || !rhsAttr)
1248 return {};
1249
1250 return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
1251 ComparisonFold<std::equal_to<APFloat>>>(lhs: lhsAttr, rhs: rhsAttr,
1252 returnTy: resultTy);
1253}
1254
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  // A cast to the identical type is the identity.
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(Val: adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(Val: getInput().getType());
  auto outTy = llvm::cast<ShapedType>(Val: getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  // Only splat constants are folded; each branch handles one
  // (input element type, output element type) pairing.
  if (operand.isSplat()) {
    // float -> float: convert between semantics, round to nearest even.
    if (llvm::isa<FloatType>(Val: inETy) && llvm::isa<FloatType>(Val: outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(Val&: outETy).getFloatSemantics();
      splatVal.convert(ToSemantics: semantics, RM: llvm::RoundingMode::NearestTiesToEven,
                       losesInfo: &overflow);
      return SplatElementsAttr::get(type: outTy, values: splatVal);
    }

    // int -> float: honor the input's (un)signedness during conversion.
    if (llvm::isa<IntegerType>(Val: inETy) && llvm::isa<FloatType>(Val: outETy)) {
      auto unsign = llvm::cast<IntegerType>(Val&: inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(Val&: outETy).getFloatSemantics());
      splatVal.convertFromAPInt(Input: operand.getSplatValue<APInt>(), IsSigned: !unsign,
                                RM: llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(type: outTy, values: splatVal);
    }

    // float -> int: round to nearest even into an APSInt of the output width.
    if (llvm::isa<FloatType>(Val: inETy) && llvm::isa<IntegerType>(Val: outETy)) {
      auto unsign = llvm::cast<IntegerType>(Val&: outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(Val&: outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(Result&: intVal, RM: llvm::RoundingMode::NearestTiesToEven,
                                IsExact: &exact);
      return SplatElementsAttr::get(type: outTy, values: intVal);
    }

    // int -> int: truncate when narrowing; otherwise zero- or sign-extend
    // depending on the input's signedness.
    if (llvm::isa<IntegerType>(Val: inETy) && llvm::isa<IntegerType>(Val: outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(Val&: inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(width: bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(width: bitwidth);
      } else {
        intVal = intVal.sext(width: bitwidth);
      }

      return SplatElementsAttr::get(type: outTy, values: intVal);
    }
  }

  return {};
}
1318
// A tosa.const folds trivially to the attribute it holds.
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1320
// A tosa.const_shape likewise folds to its stored attribute.
OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1322
// Shared folder for all TOSA reductions: reducing a rank-0 tensor, or along
// an axis of extent one, is the identity provided the result type matches
// the input type exactly.
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }

REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProductOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER
1342
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(Val: getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(Val: getType());

  if (!inputTy || !outputTy)
    return {};

  // Fold when the input and output types are the same. This is only safe when
  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
  // there may still be a productive reshape.
  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x)
  // Note this mutates this op in place to consume the inner reshape's input.
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          Val: getInput1().getDefiningOp())) {
    getInput1Mutable().assign(value: reshapeOp.getInput1());
    return getResult();
  }

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!inputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  // reshape(const(x)) -> const(reshape-attr(x))
  if (auto operand =
          llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput1())) {
    // Constants must have static shape.
    if (!outputTy.hasStaticShape())
      return {};

    // Okay to duplicate splat constants.
    if (operand.isSplat())
      return SplatElementsAttr::get(type: outputTy,
                                    values: operand.getSplatValue<Attribute>());

    // Don't duplicate other constants.
    if (!getInput1().hasOneUse())
      return {};

    // The target shape must itself be a compile-time constant.
    llvm::SmallVector<int64_t> shapeVec;
    if (!tosa::getConstShapeValues(op: getShape().getDefiningOp(), result_shape&: shapeVec))
      return {};

    return operand.reshape(
        newType: llvm::cast<ShapedType>(Val: operand.getType()).clone(shape: shapeVec));
  }

  return {};
}
1393
1394OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1395 // If the pad is all zeros we can fold this operation away.
1396 if (adaptor.getPadding() && getInput1().getType() == getType()) {
1397 auto densePad = llvm::dyn_cast<DenseElementsAttr>(Val: adaptor.getPadding());
1398 if (densePad && densePad.isSplat() &&
1399 densePad.getSplatValue<APInt>().isZero()) {
1400 return getInput1();
1401 }
1402 }
1403
1404 return {};
1405}
1406
1407// Fold away cases where a tosa.resize operation returns a copy
1408// of the input image.
1409OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1410 auto scaleAttr =
1411 llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getScale());
1412 auto offsetAttr =
1413 llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getOffset());
1414 auto borderAttr =
1415 llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getBorder());
1416 if (!scaleAttr || !offsetAttr || !borderAttr) {
1417 return {};
1418 }
1419
1420 auto scale = tosa::convertFromIntAttr(attr: scaleAttr, /* rank = */ 4);
1421 auto offset = tosa::convertFromIntAttr(attr: offsetAttr, /* rank = */ 2);
1422 auto border = tosa::convertFromIntAttr(attr: borderAttr, /* rank = */ 2);
1423 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1424 return {};
1425 }
1426
1427 // Check unit scaling.
1428 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1429 return {};
1430 }
1431
1432 // There should be no offset.
1433 if (offset[0] != 0 || offset[1] != 0) {
1434 return {};
1435 }
1436
1437 // There should be no border.
1438 if (border[0] != 0 || border[1] != 0) {
1439 return {};
1440 }
1441
1442 auto input = getInput();
1443 auto inputTy = llvm::cast<RankedTensorType>(Val: input.getType());
1444 auto resultTy = llvm::cast<RankedTensorType>(Val: getType());
1445 if (inputTy != resultTy)
1446 return {};
1447
1448 return input;
1449}
1450
1451OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1452 auto operand = getInput1();
1453 auto operandTy = llvm::cast<ShapedType>(Val: operand.getType());
1454 auto axis = getAxis();
1455 auto operandAttr =
1456 llvm::dyn_cast_if_present<SplatElementsAttr>(Val: adaptor.getInput1());
1457 if (operandAttr)
1458 return operandAttr;
1459
1460 // If the dim-length is 1, tosa.reverse is a no-op.
1461 if (operandTy.hasRank() &&
1462 (operandTy.getRank() == 0 || operandTy.getDimSize(idx: axis) == 1))
1463 return operand;
1464
1465 return {};
1466}
1467
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(Val: getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(Val: getType());

  if (!inputTy || !outputTy)
    return {};

  // A slice whose static result type equals its input type copies everything.
  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput1();

  if (!adaptor.getInput1())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  // Slicing a splat constant yields a (smaller) splat of the same value.
  auto operand = llvm::cast<ElementsAttr>(Val: adaptor.getInput1());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(type: outputTy, values: operand.getSplatValue<Attribute>());
  }

  // A one-element slice of a static constant can be extracted directly: the
  // constant start indices address exactly that element.
  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    DenseElementsAttr startElems;
    if (!matchPattern(value: getStart(), pattern: m_Constant(bind_value: &startElems)))
      return {};

    llvm::SmallVector<uint64_t> indices =
        llvm::to_vector(Range: startElems.getValues<uint64_t>());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(type: outputTy, values: value);
  }

  return {};
}
1505
1506OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1507 if (getOnTrue() == getOnFalse())
1508 return getOnTrue();
1509
1510 auto predicate =
1511 llvm::dyn_cast_if_present<DenseIntElementsAttr>(Val: adaptor.getInput1());
1512 if (!predicate)
1513 return {};
1514
1515 if (!predicate.isSplat())
1516 return {};
1517 return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1518 : getOnFalse();
1519}
1520
1521OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1522 if (getInput1().getType() == getType()) {
1523 if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1524 Val: adaptor.getMultiples())) {
1525 if (multiples.isSplat() &&
1526 multiples.getSplatValue<APInt>().getSExtValue() == 1)
1527 return getInput1();
1528 if (auto int_array_attr =
1529 llvm::dyn_cast<DenseIntElementsAttr>(Val&: multiples)) {
1530 if (llvm::all_of(Range: int_array_attr.getValues<APInt>(),
1531 P: [](APInt v) { return v.getSExtValue() == 1; }))
1532 return getInput1();
1533 }
1534 }
1535 }
1536 return {};
1537}
1538
1539OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1540 auto resultTy = llvm::cast<ShapedType>(Val: getType());
1541
1542 // Transposing splat values just means reshaping.
1543 if (auto input =
1544 llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getInput1())) {
1545 if (input.isSplat() && resultTy.hasStaticShape() &&
1546 input.getType().getElementType() == resultTy.getElementType())
1547 return input.reshape(newType: resultTy);
1548 }
1549
1550 // Transpose is not the identity transpose.
1551 const llvm::ArrayRef<int32_t> perms = getPerms();
1552
1553 if (!llvm::equal(LRange: llvm::seq<int32_t>(Begin: 0, End: perms.size()), RRange: perms))
1554 return {};
1555
1556 return getInput1();
1557}
1558
1559OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
1560 auto input = getInput1();
1561 // Element-wise log(exp(x)) = x
1562 if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
1563 return op.getInput1();
1564 }
1565
1566 return {};
1567}
1568
1569OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
1570 auto input = getInput1();
1571 // Element-wise exp(log(x)) = x
1572 if (auto op = input.getDefiningOp<tosa::LogOp>()) {
1573 return op.getInput1();
1574 }
1575
1576 return {};
1577}
1578
1579OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1580 // Element-wise negate(negate(x)) = x
1581 // iff all zero points are constant 0
1582 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1583 if (!definingOp) {
1584 // defining op of input1 is not a negate, cannot fold
1585 return {};
1586 }
1587
1588 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1589 failed(Result: maybeIZp) || *maybeIZp != 0) {
1590 // input1 zero point is not constant 0, cannot fold
1591 return {};
1592 }
1593 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1594 failed(Result: maybeOZp) || *maybeOZp != 0) {
1595 // output zero point is not constant 0, cannot fold
1596 return {};
1597 }
1598 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1599 failed(Result: maybeIZp) || *maybeIZp != 0) {
1600 // definingOp's input1 zero point is not constant 0, cannot fold
1601 return {};
1602 }
1603 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1604 failed(Result: maybeOZp) || *maybeOZp != 0) {
1605 // definingOp's output zero point is not constant 0, cannot fold
1606 return {};
1607 }
1608
1609 return definingOp.getInput1();
1610}
1611
1612OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1613 auto input = getInput1();
1614 // Element-wise abs(abs(x)) = abs(x)
1615 if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
1616 return input;
1617 }
1618
1619 return {};
1620}
1621
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Fold consecutive concats on the same axis into a single op.
  // Keep track of the operands so we are able to construct a new concat
  // later. Conservatively assume that we double the number of operands when
  // folding
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(N: 2 * getNumOperands());

  // Find all operands that are foldable concats
  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(Args&: operand);

    // Only an operand produced by another concat can be folded in.
    auto producer = dyn_cast_or_null<ConcatOp>(Val: operand.getDefiningOp());
    if (!producer)
      continue;

    // Not foldable if axes are not the same
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the original operand with all incoming operands
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(C&: concatOperands, R: producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  // Mutate this op in place so it consumes the producers' operands directly.
  getOperation()->setOperands(concatOperands);
  return getResult();
}
1655
1656OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1657 auto input = adaptor.getInput1();
1658
1659 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(Val&: input);
1660 // Fold splat inputs only.
1661 if (!inputAttr || !inputAttr.isSplat())
1662 return {};
1663
1664 auto shapeType = llvm::cast<ShapedType>(Val: getType());
1665 if (auto floatType = llvm::dyn_cast<FloatType>(Val: inputAttr.getElementType())) {
1666 auto floatVal = inputAttr.getSplatValue<APFloat>();
1667 return DenseElementsAttr::get(type: shapeType,
1668 values: ReciprocalOp::calcOneElement(operand: floatVal));
1669 }
1670
1671 return {};
1672}
1673

source code of mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp