1//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// TOSA canonicalization patterns and folders.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Quant/IR/Quant.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/Dialect/Tosa/IR/TosaOps.h"
17#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
18#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
19#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
20#include "mlir/IR/BuiltinTypeInterfaces.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/IR/DialectImplementation.h"
23#include "mlir/IR/Matchers.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/Transforms/FoldUtils.h"
26#include "mlir/Transforms/InliningUtils.h"
27#include "mlir/Transforms/RegionUtils.h"
28#include "llvm/ADT/APFloat.h"
29#include "llvm/ADT/APInt.h"
30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/TypeSwitch.h"
32
33#include <functional>
34
35using namespace mlir;
36using namespace mlir::tosa;
37
38//===----------------------------------------------------------------------===//
39// Operator Canonicalizers.
40//===----------------------------------------------------------------------===//
41
42//===----------------------------------------------------------------------===//
43// Tensor Data Engine Operators.
44//===----------------------------------------------------------------------===//
45
46// Check that the zero point of the tensor and padding operations are aligned.
47bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
48 // Check that padConst is a constant value and a scalar tensor
49 DenseElementsAttr padConstAttr;
50 if (!matchPattern(value: padConst, pattern: m_Constant(bind_value: &padConstAttr)) ||
51 (padConstAttr.size() != 1)) {
52 return false;
53 }
54
55 // Check that floating point pad is zero
56 if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
57 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
58 return padConstVal == 0.0f;
59 }
60
61 // Check that the zp and padConst align for the integer (quantized) case
62 if (auto padConstIntAttr =
63 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
64 DenseIntElementsAttr zpAttr;
65 // Check that zp is a constant value and a scalar tensor
66 if (!matchPattern(value: zp, pattern: m_Constant(bind_value: &zpAttr)) || (padConstAttr.size() != 1)) {
67 return false;
68 }
69
70 // Check equality
71 int64_t zpVal = (*zpAttr.begin()).getSExtValue();
72 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
73 return zpVal == padConstVal;
74 }
75
76 // Bail-out on unsupported type
77 return false;
78}
79
80namespace {
81template <typename OpTy>
82struct PoolPadFoldAdaptor;
83
84template <>
85struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
86 using OpTy = tosa::AvgPool2dOp;
87 static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
88 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
89 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
90 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
91 return false;
92 return true;
93 }
94 static bool checkPadConstCompliance(OpTy op, Value padConst) {
95 return checkMatchingPadConstAndZp(padConst, op.getInputZp());
96 }
97 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
98 Value padInput, ArrayRef<int64_t> newPad) {
99 rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
100 op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
101 op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
102 op.getAccType());
103 }
104};
105
106template <>
107struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
108 using OpTy = tosa::MaxPool2dOp;
109 static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
110 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
111 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
112 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
113 return false;
114 return true;
115 }
116 static bool checkPadConstCompliance(OpTy, Value padConst) {
117 // Check that padConst is a constant value and a scalar tensor
118 DenseElementsAttr padConstAttr;
119 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
120 padConstAttr.size() != 1) {
121 return false;
122 }
123
124 // Pad needs to be in the minimum value to be able to merge
125 if (auto padConstFpAttr =
126 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
127 const APFloat padConstVal = *padConstFpAttr.begin();
128 const APFloat lowestVal =
129 APFloat::getLargest(padConstVal.getSemantics(), true);
130 return padConstVal == lowestVal;
131 } else if (auto padConstIntAttr =
132 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
133 const APInt padConstVal = *padConstIntAttr.begin();
134 const unsigned int bitWidth = padConstVal.getBitWidth();
135 const APInt lowestVal =
136 padConstIntAttr.getElementType().isUnsignedInteger()
137 ? APInt::getZero(bitWidth)
138 : APInt::getSignedMinValue(bitWidth);
139 return padConstVal == lowestVal;
140 }
141
142 // Bail-out on unsupported type
143 return false;
144 }
145 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
146 Value padInput, ArrayRef<int64_t> newPad) {
147 rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
148 op, op.getType(), padInput, op.getKernel(), op.getStride(),
149 rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
150 }
151};
152
153template <typename OpTy>
154struct ConvPadFoldAdaptor {
155 static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
156 return true;
157 }
158 static bool checkPadConstCompliance(OpTy op, Value padConst) {
159 return checkMatchingPadConstAndZp(padConst, op.getInputZp());
160 }
161 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
162 Value padInput, ArrayRef<int64_t> newPad) {
163 rewriter.replaceOpWithNewOp<OpTy>(
164 op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
165 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
166 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
167 }
168};
169
170// Pattern attempts to fold a `tosa.pad` operator to a following tensor
171// operation like `tosa.conv2d` by merging the padding associated with the
172// pad operator directly to the implicit padding of the tensor operation.
173// This helps eliminate the explicit padding operator if unused.
174template <typename OpTy, typename AdaptorTy>
175struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
176 using OpRewritePattern<OpTy>::OpRewritePattern;
177
178 LogicalResult matchAndRewrite(OpTy tensorOp,
179 PatternRewriter &rewriter) const override {
180 // Check producer is a tosa::PadOp
181 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
182 if (!padOp)
183 return rewriter.notifyMatchFailure(tensorOp,
184 "Producer must be a tosa::PadOp.");
185
186 // Validate that tensor operation has sane padding
187 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
188 if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
189 return rewriter.notifyMatchFailure(
190 tensorOp, "Tensor operation padding shall have 4 elements.");
191
192 // Validate tosa::PadOp padding
193 DenseIntElementsAttr padOpPadding;
194 if (!matchPattern(padOp.getPadding(), m_Constant(bind_value: &padOpPadding))) {
195 return rewriter.notifyMatchFailure(
196 tensorOp,
197 "The `padding` input specified on the tosa::PadOp must be constant.");
198 }
199 // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
200 // C_after
201 if (padOpPadding.size() != 8)
202 return rewriter.notifyMatchFailure(tensorOp,
203 "Pad padding should have 8 elements.");
204 int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
205 int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
206 int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
207 int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
208 int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
209 int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
210 int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
211 int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
212
213 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
214 return rewriter.notifyMatchFailure(
215 tensorOp, "Folding padding in N or C dimensions is not supported.");
216
217 // Fold padding from Pad into the tensor operation
218 // 4 elements - pad_top, pad_bottom, pad_left, pad_right
219 SmallVector<int64_t> foldedPad(tensorOpPad.size());
220 foldedPad[0] = padHBefore + tensorOpPad[0];
221 foldedPad[1] = padHAfter + tensorOpPad[1];
222 foldedPad[2] = padWBefore + tensorOpPad[2];
223 foldedPad[3] = padWAfter + tensorOpPad[3];
224
225 // Check kernel related restrictions
226 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
227 return rewriter.notifyMatchFailure(
228 tensorOp, "Padding size not aligned with kernel restrictions.");
229 }
230
231 // Check padding constant restrictions
232 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
233 return rewriter.notifyMatchFailure(
234 tensorOp,
235 "Padding constant is not aligned with operator zero-point.");
236 }
237
238 // Check that padding doesn't grow more than 8K level (8192) for now
239 if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
240 return rewriter.notifyMatchFailure(
241 tensorOp, "Padding size more than the 8K level limit.");
242 }
243
244 // Create operator
245 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
246 foldedPad);
247
248 return success();
249 }
250};
251} // namespace
252
253void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
254 MLIRContext *context) {
255 results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
256 PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
257 context);
258}
259
260void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
261 MLIRContext *context) {
262 results.add<
263 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
264 context);
265}
266
267void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
268 MLIRContext *context) {
269 results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
270 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
271 context);
272}
273
274struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
275 using OpRewritePattern::OpRewritePattern;
276
277 LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
278 PatternRewriter &rewriter) const override {
279 Value input = op.getInput();
280 Value output = op.getOutput();
281 ShapedType inputType = llvm::cast<ShapedType>(input.getType());
282 ShapedType outputType = llvm::cast<ShapedType>(output.getType());
283
284 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
285 return failure();
286 }
287
288 // If the output and input shapes are 1x1, then this is a no op.
289 ArrayRef<int64_t> outputShape = outputType.getShape();
290 if (outputShape[1] != 1 || outputShape[2] != 1) {
291 return failure();
292 }
293
294 ArrayRef<int64_t> inputShape = inputType.getShape();
295 if (inputShape[1] != 1 || inputShape[2] != 1) {
296 return failure();
297 }
298
299 rewriter.replaceOp(op, input);
300 return success();
301 }
302};
303
304void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
305 MLIRContext *context) {
306 results.add<MaxPool2dIsNoOp,
307 FoldPadToTensorOp<tosa::MaxPool2dOp,
308 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
309 context);
310}
311
312//===----------------------------------------------------------------------===//
313// Data Layout / Memory Reinterpretation.
314//===----------------------------------------------------------------------===//
315
316struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
317 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
318
319 LogicalResult matchAndRewrite(tosa::ConcatOp op,
320 PatternRewriter &rewriter) const override {
321 if (op.getInput1().size() != 1)
322 return failure();
323 if (op.getInput1().front().getType() != op.getType()) {
324 rewriter
325 .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
326 op.getInput1().front())
327 .getResult();
328 return success();
329 }
330
331 rewriter.replaceOp(op, op.getInput1().front());
332 return success();
333 }
334};
335
336void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
337 MLIRContext *context) {
338 results.add<ConcatOptimization>(context);
339}
340
341LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
342 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
343 if (!notOp)
344 return failure();
345 rewriter.modifyOpInPlace(op, [&]() {
346 op.getOperation()->setOperands(
347 {notOp.getInput1(), op.getInput3(), op.getInput2()});
348 });
349 return success();
350}
351
352struct ConsolidateTransposeOptimization
353 : public OpRewritePattern<tosa::TransposeOp> {
354 using OpRewritePattern::OpRewritePattern;
355
356 LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
357 PatternRewriter &rewriter) const override {
358 // Input is also TransposeOp - transpose(transpose(A)).
359 auto innerTranspose =
360 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
361 if (!innerTranspose)
362 return rewriter.notifyMatchFailure(transposeOp,
363 "input must be transpose operation");
364
365 const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
366 const llvm::ArrayRef<int32_t> innerTransposePerms =
367 innerTranspose.getPerms();
368
369 if (transposePerms.size() != innerTransposePerms.size())
370 return rewriter.notifyMatchFailure(
371 transposeOp,
372 "transpose and inner transpose perms sizes must be equal");
373 if (transposePerms.empty())
374 return rewriter.notifyMatchFailure(
375 transposeOp, "transpose perms sizes must be positive");
376
377 // Consolidate transposes into one transpose.
378 SmallVector<int32_t> perms(transposePerms.size());
379 for (int i = 0, s = transposePerms.size(); i < s; ++i)
380 perms[i] = innerTransposePerms[transposePerms[i]];
381
382 rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
383 transposeOp, transposeOp.getResult().getType(),
384 innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
385
386 return success();
387 }
388};
389
390// Determines the case when tosa.transpose is a tosa.reshape operation.
391struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
392 using OpRewritePattern::OpRewritePattern;
393
394 LogicalResult matchAndRewrite(tosa::TransposeOp op,
395 PatternRewriter &rewriter) const override {
396 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
397 return rewriter.notifyMatchFailure(
398 op, "Src is from transpose, can compose transposes");
399
400 Value result = op.getResult();
401 for (Operation *subop : result.getUsers()) {
402 if (isa_and_nonnull<tosa::TransposeOp>(subop))
403 return rewriter.notifyMatchFailure(
404 op, "Dest is used by transpose, can compose transposes");
405 }
406
407 auto input = op.getInput1();
408 auto inputTy = llvm::cast<ShapedType>(input.getType());
409 if (!inputTy.hasRank())
410 return rewriter.notifyMatchFailure(op, "Unranked input.");
411
412 int64_t numDynDims = 0;
413 for (int i = 0; i < inputTy.getRank(); ++i)
414 if (inputTy.isDynamicDim(i))
415 numDynDims++;
416
417 if (numDynDims > 1)
418 return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
419
420 const llvm::ArrayRef<int32_t> permValues = op.getPerms();
421
422 SmallVector<int64_t> nonZeroPerms;
423 nonZeroPerms.reserve(N: permValues.size());
424 for (auto idx : permValues) {
425 auto sz = inputTy.getDimSize(idx);
426 if (sz != 1)
427 nonZeroPerms.push_back(idx);
428 }
429
430 for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
431 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
432 return rewriter.notifyMatchFailure(op,
433 "Transpose changes memory layout.");
434
435 SmallVector<int64_t> newShape;
436 newShape.reserve(N: inputTy.getRank());
437 for (int i = 0, s = inputTy.getRank(); i < s; ++i)
438 newShape.push_back(Elt: inputTy.getDimSize(permValues[i]));
439
440 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
441 op, op.getType(), op.getInput1(),
442 getTosaConstShape(rewriter, op.getLoc(), newShape));
443 return success();
444 }
445};
446
447void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
448 MLIRContext *context) {
449 results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
450}
451
452struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
453 using OpRewritePattern::OpRewritePattern;
454
455 LogicalResult matchAndRewrite(tosa::ClampOp op,
456 PatternRewriter &rewriter) const override {
457 Value input = op.getInput();
458 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
459 auto inputElementType = inputType.getElementType();
460
461 if (!inputType.hasStaticShape()) {
462 return failure();
463 }
464
465 if (isa<FloatType>(inputElementType)) {
466 // Unlike integer types, floating point types can represent infinity.
467 auto minClamp =
468 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
469 auto maxClamp =
470 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
471 bool isMin = minClamp.isNegInfinity();
472 bool isMax = maxClamp.isInfinity();
473
474 if (isMin && isMax) {
475 rewriter.replaceOp(op, input);
476 return success();
477 }
478 return failure();
479 }
480
481 if (inputElementType.isUnsignedInteger()) {
482 int64_t minClamp =
483 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
484 int64_t maxClamp =
485 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();
486
487 int64_t intMin =
488 APInt::getMinValue(numBits: inputElementType.getIntOrFloatBitWidth())
489 .getZExtValue();
490 int64_t intMax =
491 APInt::getMaxValue(numBits: inputElementType.getIntOrFloatBitWidth())
492 .getZExtValue();
493
494 if (minClamp <= intMin && maxClamp >= intMax) {
495 rewriter.replaceOp(op, input);
496 return success();
497 }
498 return failure();
499 }
500
501 if (llvm::isa<IntegerType>(inputElementType)) {
502 int64_t minClamp =
503 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
504 int64_t maxClamp =
505 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
506
507 int64_t intMin =
508 APInt::getSignedMinValue(numBits: inputElementType.getIntOrFloatBitWidth())
509 .getSExtValue();
510 int64_t intMax =
511 APInt::getSignedMaxValue(numBits: inputElementType.getIntOrFloatBitWidth())
512 .getSExtValue();
513
514 if (minClamp <= intMin && maxClamp >= intMax) {
515 rewriter.replaceOp(op, input);
516 return success();
517 }
518 return failure();
519 }
520
521 return failure();
522 }
523};
524
525// Attempts the following transformation:
526//
527// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
528// tensor X the following identity holds:
529//
530// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
531//
532// subject to the following valid NaN propagation semantics:
533// --------------------------------------------
534// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
535// |-------------|--------------|-------------|
536// | PROPAGATE | PROPAGATE | PROPAGATE |
537// | PROPAGATE | IGNORE | IGNORE |
538// | IGNORE | PROPAGATE | INVALID |
539// | IGNORE | IGNORE | IGNORE |
540// |------------------------------------------|
541
542struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
543 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
544
545 // Helper structure to describe the range of a clamp operation.
546 template <typename T>
547 struct ClampRange {
548 ClampRange(const T &start, const T &end) : start(start), end(end) {}
549 T start;
550 T end;
551
552 // Helper function to determine if two Clamp ranges intersect.
553 bool intersects(const ClampRange<T> &otherRange) {
554 return start < otherRange.end && otherRange.start < end;
555 }
556 };
557
558 LogicalResult matchAndRewrite(tosa::ClampOp op,
559 PatternRewriter &rewriter) const override {
560 Value input = op.getInput();
561
562 // Check the input to the CLAMP op is itself a CLAMP.
563 auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
564 if (!clampOp)
565 return failure();
566
567 // Check we have a valid NaN propagation combination.
568 const auto opNanMode = op.getNanMode();
569 const auto clampNanMode = clampOp.getNanMode();
570 if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
571 return failure();
572
573 auto maxValAttr = op.getMaxValAttr();
574 auto minValAttr = op.getMinValAttr();
575 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
576 auto clampOpMinValAttr = clampOp.getMinValAttr();
577
578 auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
579 if (auto quantType =
580 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
581 inputEType = quantType.getStorageType();
582 }
583
584 Attribute newMinValAttr, newMaxValAttr;
585 if (mlir::isa<FloatType>(inputEType)) {
586 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
587 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
588 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
589 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
590
591 // Check we have intersecting ranges.
592 const auto opMinFloat = floatMinValAttr.getValue();
593 const auto opMaxFloat = floatMaxValAttr.getValue();
594 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
595 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
596 ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
597 ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
598 clampOpMaxFloat);
599 if (!opRangeFloatRange.intersects(otherRange: clampRangeFloatRange))
600 return failure();
601
602 // Run the transformation.
603 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
604 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
605 newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
606 newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
607 } else {
608 assert(mlir::isa<IntegerType>(inputEType));
609 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
610 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
611 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
612 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
613
614 if (inputEType.isUnsignedInteger()) {
615 // Check we have intersecting ranges.
616 const auto opMinInt = intMinValAttr.getUInt();
617 const auto opMaxInt = intMaxValAttr.getUInt();
618 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
619 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
620 ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
621 ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
622 clampOpMaxInt);
623 if (!opRangeIntRange.intersects(otherRange: clampRangeIntRange))
624 return failure();
625
626 // Run the transformation.
627 auto newMinVal = std::max(opMinInt, clampOpMinInt);
628 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
629 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
630 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
631 } else {
632 // Check we have intersecting ranges.
633 const auto opMinInt = intMinValAttr.getInt();
634 const auto opMaxInt = intMaxValAttr.getInt();
635 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
636 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
637 ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
638 ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
639 clampOpMaxInt);
640 if (!opRangeIntRange.intersects(otherRange: clampRangeIntRange))
641 return failure();
642
643 // Run the transformation.
644 auto newMinVal = std::max(opMinInt, clampOpMinInt);
645 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
646 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
647 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
648 }
649 }
650
651 rewriter.replaceOpWithNewOp<tosa::ClampOp>(
652 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
653 rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
654 : opNanMode));
655 return success();
656 }
657};
658
659void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
660 MLIRContext *context) {
661 results.add<ClampIsNoOp>(context);
662 results.add<ClampClampOptimization>(context);
663}
664
665struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
666 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
667
668 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
669 PatternRewriter &rewriter) const override {
670 Value sliceInput = sliceOp.getInput1();
671 auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
672 if (!concatOp)
673 return rewriter.notifyMatchFailure(
674 sliceOp, "slice input must be concat operation");
675
676 OperandRange inputs = concatOp.getInput1();
677 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
678 if (!concatType || !concatType.hasStaticShape())
679 return rewriter.notifyMatchFailure(
680 sliceOp, "slice input must be a static ranked tensor");
681 int32_t axis = concatOp.getAxis();
682
683 DenseElementsAttr startElems;
684 DenseElementsAttr sizeElems;
685
686 if (!matchPattern(sliceOp.getStart(), m_Constant(bind_value: &startElems)))
687 return rewriter.notifyMatchFailure(
688 sliceOp, "start of slice must be a static ranked shape");
689
690 if (!matchPattern(sliceOp.getSize(), m_Constant(bind_value: &sizeElems)))
691 return rewriter.notifyMatchFailure(
692 sliceOp, "size of slice must be a static ranked shape");
693
694 llvm::SmallVector<int64_t> sliceStarts =
695 llvm::to_vector(startElems.getValues<int64_t>());
696 llvm::SmallVector<int64_t> sliceSizes =
697 llvm::to_vector(sizeElems.getValues<int64_t>());
698
699 // Validate slice on the concatenated axis. Slicing along this
700 // axis should span only one of the inputs to the concatenate
701 // operation.
702 std::optional<Value> replaceWithSlice;
703 for (auto input : inputs) {
704 auto inputType = dyn_cast<RankedTensorType>(input.getType());
705 if (!inputType || !inputType.hasStaticShape())
706 return rewriter.notifyMatchFailure(
707 sliceOp, "concat input must be a static ranked tensor");
708
709 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
710 inputType.getDimSize(axis)) {
711 auto start_op =
712 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
713 auto size_op =
714 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
715 replaceWithSlice =
716 rewriter
717 .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
718 input, start_op, size_op)
719 .getResult();
720 break;
721 }
722 sliceStarts[axis] -= inputType.getDimSize(axis);
723 }
724
725 if (!replaceWithSlice)
726 return rewriter.notifyMatchFailure(
727 sliceOp, "corresponding concat input not found for slice");
728
729 rewriter.replaceOp(sliceOp, replaceWithSlice.value());
730 return success();
731 }
732};
733
734struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
735 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
736
737 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
738 PatternRewriter &rewriter) const override {
739 Value sliceInput = sliceOp.getInput1();
740
741 // Check if producer is a PadOp
742 auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
743 if (!padOp)
744 return rewriter.notifyMatchFailure(sliceOp,
745 "slice input must be a pad operation");
746
747 // Check PadOp has a single consumer
748 if (!padOp->hasOneUse())
749 return rewriter.notifyMatchFailure(sliceOp,
750 "pad shall have a single consumer");
751
752 // Check input is statically ranked
753 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
754 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
755 if (!inputTy || !padTy || !inputTy.hasRank())
756 return rewriter.notifyMatchFailure(sliceOp,
757 "slice input must be a ranked tensor");
758
759 // Validate and extract tosa::PadOp padding
760 DenseIntElementsAttr paddingElems;
761 if (!matchPattern(padOp.getPadding(), m_Constant(bind_value: &paddingElems))) {
762 return rewriter.notifyMatchFailure(
763 sliceOp,
764 "`padding` input specified on the tosa::PadOp must be constant.");
765 }
766 llvm::SmallVector<int64_t> padPaddings =
767 llvm::to_vector(paddingElems.getValues<int64_t>());
768
769 // Extract slice parameters
770 DenseElementsAttr startElems;
771 if (!matchPattern(sliceOp.getStart(), m_Constant(bind_value: &startElems)))
772 return rewriter.notifyMatchFailure(
773 sliceOp, "start of slice must be a static ranked shape");
774 llvm::SmallVector<int64_t> sliceStarts =
775 llvm::to_vector(startElems.getValues<int64_t>());
776
777 DenseElementsAttr sizeElems;
778 if (!matchPattern(sliceOp.getSize(), m_Constant(bind_value: &sizeElems)))
779 return rewriter.notifyMatchFailure(
780 sliceOp, "size of slice must be a static ranked shape");
781 llvm::SmallVector<int64_t> sliceSizes =
782 llvm::to_vector(sizeElems.getValues<int64_t>());
783
784 // Check if dynamic dimensions are sliced
785 const int64_t rank = inputTy.getRank();
786 if (llvm::any_of(Range: llvm::seq<int64_t>(Begin: 0, End: rank), P: [&](int64_t i) {
787 const bool isDimDynamic = inputTy.isDynamicDim(i);
788 const bool isDimSliced =
789 (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
790
791 return isDimDynamic && isDimSliced;
792 })) {
793 return rewriter.notifyMatchFailure(
794 sliceOp, "axis that are sliced shall be statically known.");
795 }
796
797 // Update the parameters
798 llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
799 llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
800 llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
801 bool updated = false;
802
803 for (int64_t i = 0; i < rank; ++i) {
804 const int64_t padLo = padPaddings[i * 2];
805 const int64_t padHi = padPaddings[i * 2 + 1];
806 const int64_t sliceStart = sliceStarts[i];
807 const int64_t sliceSize = sliceSizes[i];
808 const int64_t sliceEnd = sliceStart + sliceSize;
809
810 // If dimension is dynamic pass-through
811 if (inputTy.isDynamicDim(i)) {
812 newPadPaddings[i * 2] = padLo;
813 newPadPaddings[i * 2 + 1] = padHi;
814 newSliceStarts[i] = sliceStart;
815 continue;
816 }
817
818 // Handle static dimensions
819 const int64_t dimSize = inputTy.getShape()[i];
820 const int64_t dimTotal = padLo + dimSize + padHi;
821
822 // Check slice within bounds
823 if (sliceStart < 0 || sliceEnd > dimTotal)
824 return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");
825
826 // Compute updated slice start parameter
827 const int64_t newSliceStart = std::max<int64_t>(a: sliceStart - padLo, b: 0);
828 newSliceStarts[i] = newSliceStart;
829 updated |= newSliceStart != sliceStart;
830
831 // Compute updated pad parameters
832 const int64_t newPadLo = std::max<int64_t>(a: padLo - sliceStart, b: 0);
833 const int64_t newPadHi =
834 std::max<int64_t>(a: sliceEnd - (padLo + dimSize), b: 0);
835 newPadPaddings[i * 2] = newPadLo;
836 newPadPaddings[i * 2 + 1] = newPadHi;
837 updated |= (newPadLo != padLo) || (newPadHi != padHi);
838
839 // Calculate new pad output shape
840 newPadShape[i] =
841 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
842 }
843
844 // Check that we actually need to proceed with the rewrite
845 if (!updated)
846 return rewriter.notifyMatchFailure(
847 sliceOp, "terminate condition; nothing to rewrite");
848
849 // Create a PadOp with updated padding
850 auto newPaddingsOp =
851 getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
852 auto newPadTy =
853 RankedTensorType::get(newPadShape, inputTy.getElementType());
854 auto newPadOp = rewriter.create<tosa::PadOp>(
855 padOp.getLoc(), newPadTy, padOp.getInput1(), newPaddingsOp,
856 padOp.getPadConst());
857
858 // Update SliceOp and point to new PadOp
859 auto newStartOp =
860 getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
861 rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
862 newPadOp.getResult(), newStartOp,
863 sliceOp.getSize());
864
865 return success();
866 }
867};
868
869// Update size operand of tosa.slice if size has dynamic dims but corresponding
870// output dim is static
871struct SliceDynamicSizeCanonicalization
872 : public OpRewritePattern<tosa::SliceOp> {
873 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
874
875 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
876 PatternRewriter &rewriter) const override {
877 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
878
879 ElementsAttr sizeElems;
880 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
881 return rewriter.notifyMatchFailure(
882 sliceOp, "size of slice must be a static ranked shape");
883 }
884
885 llvm::SmallVector<int64_t> sliceSizes =
886 llvm::to_vector(sizeElems.getValues<int64_t>());
887
888 bool replaceSliceSize{false};
889 // if size op has -1 indicating dynamic shape but corresponding dim on the
890 // output is statically known, update size to match with known output dim
891 // shape
892 for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
893 if (size == -1 && !resultType.isDynamicDim(index)) {
894 sliceSizes[index] = resultType.getDimSize(index);
895 replaceSliceSize = true;
896 }
897 }
898
899 if (!replaceSliceSize) {
900 return rewriter.notifyMatchFailure(
901 sliceOp, "no dimension of size of slice is dynamic that resolves "
902 "to static output shape");
903 }
904
905 auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
906 auto newSliceOp = rewriter.create<tosa::SliceOp>(
907 sliceOp.getLoc(), sliceOp.getType(), sliceOp.getInput1(),
908 sliceOp.getStart(), size_op);
909
910 rewriter.replaceOp(sliceOp, newSliceOp.getResult());
911 return success();
912 }
913};
914
915void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
916 MLIRContext *context) {
917 results.add<ConcatSliceOptimization, PadSliceOptimization,
918 SliceDynamicSizeCanonicalization>(context);
919}
920
921//===----------------------------------------------------------------------===//
922// Operator Folders.
923//===----------------------------------------------------------------------===//
924
925template <typename IntFolder, typename FloatFolder>
926DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
927 RankedTensorType returnTy) {
928 if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
929 auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
930 auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
931 if (lETy != rETy)
932 return {};
933
934 if (llvm::isa<IntegerType>(lETy)) {
935 APInt l = lhs.getSplatValue<APInt>();
936 APInt r = rhs.getSplatValue<APInt>();
937 auto result = IntFolder()(l, r);
938 return DenseElementsAttr::get(returnTy, result);
939 }
940
941 if (llvm::isa<FloatType>(lETy)) {
942 APFloat l = lhs.getSplatValue<APFloat>();
943 APFloat r = rhs.getSplatValue<APFloat>();
944 auto result = FloatFolder()(l, r);
945 return DenseElementsAttr::get(returnTy, result);
946 }
947 }
948
949 return {};
950}
951
952static bool isSplatZero(Type elemType, DenseElementsAttr val) {
953 if (llvm::isa<FloatType>(Val: elemType))
954 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
955 if (llvm::isa<IntegerType>(Val: elemType))
956 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
957 return false;
958}
959
960static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
961 if (llvm::isa<FloatType>(Val: elemType))
962 return val && val.isSplat() &&
963 val.getSplatValue<APFloat>().isExactlyValue(V: 1.0);
964 if (llvm::isa<IntegerType>(Val: elemType)) {
965 const int64_t shifted = 1LL << shift;
966 return val && val.isSplat() &&
967 val.getSplatValue<APInt>().getSExtValue() == shifted;
968 }
969 return false;
970}
971
972OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
973 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
974 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
975 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
976 if (!lhsTy || !rhsTy || !resultTy)
977 return {};
978
979 // Cannot create an ElementsAttr from non-int/float/index types
980 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
981 !rhsTy.getElementType().isIntOrIndexOrFloat())
982 return {};
983
984 auto resultETy = resultTy.getElementType();
985 auto lhsAttr =
986 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
987 auto rhsAttr =
988 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
989
990 if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
991 return getInput1();
992 if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
993 return getInput2();
994
995 if (!lhsAttr || !rhsAttr)
996 return {};
997
998 return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
999 resultTy);
1000}
1001
1002OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1003 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
1004 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1005 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1006 !outputTy.hasStaticShape())
1007 return {};
1008
1009 if (inputTy.getDimSize(getAxis()) == 1)
1010 return DenseElementsAttr::get(outputTy, 0);
1011
1012 return {};
1013}
1014
1015OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
1016 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1017 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1018 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1019 if (!lhsTy || !rhsTy || !resultTy)
1020 return {};
1021 if (lhsTy != rhsTy)
1022 return {};
1023
1024 // IntDivOp inputs must be integer type, no need to check for quantized type
1025 auto resultETy = resultTy.getElementType();
1026 auto lhsAttr =
1027 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1028 auto rhsAttr =
1029 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1030 if (lhsAttr && lhsAttr.isSplat()) {
1031 if (llvm::isa<IntegerType>(resultETy) &&
1032 lhsAttr.getSplatValue<APInt>().isZero())
1033 return lhsAttr;
1034 }
1035
1036 if (rhsAttr && rhsAttr.isSplat()) {
1037 if (llvm::isa<IntegerType>(resultETy) &&
1038 rhsAttr.getSplatValue<APInt>().isOne())
1039 return getInput1();
1040 }
1041
1042 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1043 llvm::isa<IntegerType>(resultETy)) {
1044 APInt l = lhsAttr.getSplatValue<APInt>();
1045 APInt r = rhsAttr.getSplatValue<APInt>();
1046 if (!r.isZero()) {
1047 APInt result = l.sdiv(r);
1048 return DenseElementsAttr::get(resultTy, result);
1049 }
1050 }
1051
1052 return {};
1053}
1054
1055namespace {
1056// calculate lhs * rhs >> shift according to TOSA Spec
1057// return nullopt if result is not in range of int32_t when shift > 0
1058std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
1059 unsigned bitwidth) {
1060 APInt result = lhs.sext(width: 64) * rhs.sext(width: 64);
1061
1062 if (shift > 0) {
1063 auto round = APInt(64, 1) << (shift - 1);
1064 result += round;
1065 result.ashrInPlace(ShiftAmt: shift);
1066 // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
1067 if (!(result.getSExtValue() >= INT32_MIN &&
1068 result.getSExtValue() <= INT32_MAX)) {
1069 // REQUIRE failed
1070 return std::nullopt;
1071 }
1072 }
1073
1074 return result.trunc(width: bitwidth);
1075}
1076
1077DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
1078 RankedTensorType ty, int32_t shift) {
1079 if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
1080 if (llvm::isa<IntegerType>(ty.getElementType())) {
1081 APInt l = lhs.getSplatValue<APInt>();
1082 APInt r = rhs.getSplatValue<APInt>();
1083
1084 if (shift == 0) {
1085 return DenseElementsAttr::get(ty, l * r);
1086 }
1087
1088 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1089 const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
1090 if (!result)
1091 return {};
1092 return DenseElementsAttr::get(ty, result.value());
1093 }
1094
1095 if (llvm::isa<FloatType>(ty.getElementType())) {
1096 APFloat l = lhs.getSplatValue<APFloat>();
1097 APFloat r = rhs.getSplatValue<APFloat>();
1098 APFloat result = l * r;
1099 return DenseElementsAttr::get(ty, result);
1100 }
1101 }
1102
1103 return {};
1104}
1105} // namespace
1106
1107OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1108 auto lhs = getInput1();
1109 auto rhs = getInput2();
1110 auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
1111 auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
1112 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1113 if (!lhsTy || !rhsTy || !resultTy)
1114 return {};
1115
1116 auto resultETy = resultTy.getElementType();
1117 auto lhsAttr =
1118 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1119 auto rhsAttr =
1120 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1121
1122 // Result right shift on i32_t data type only. For simplification, synthesize
1123 // a zero shift for other data type.
1124 int32_t shift = 0;
1125 if (resultETy.isInteger(32)) {
1126 ElementsAttr shift_elem;
1127 if (getShift().getImpl()) {
1128 if (!matchPattern(getShift(), m_Constant(&shift_elem)))
1129 // cannot be folded when the shift value is unknown.
1130 return {};
1131 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1132 }
1133 }
1134
1135 if (rhsTy == resultTy) {
1136 if (isSplatZero(resultETy, lhsAttr))
1137 return lhsAttr.resizeSplat(resultTy);
1138 if (isSplatOne(resultETy, lhsAttr, shift))
1139 return rhs;
1140 }
1141 if (lhsTy == resultTy) {
1142 if (isSplatZero(resultETy, rhsAttr))
1143 return rhsAttr.resizeSplat(resultTy);
1144 if (isSplatOne(resultETy, rhsAttr, shift))
1145 return lhs;
1146 }
1147
1148 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1149}
1150
1151OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1152 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1153 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1154 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1155 if (!lhsTy || !rhsTy || !resultTy)
1156 return {};
1157
1158 // Cannot create an ElementsAttr from non-int/float/index types
1159 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1160 !rhsTy.getElementType().isIntOrIndexOrFloat())
1161 return {};
1162
1163 auto resultETy = resultTy.getElementType();
1164 auto lhsAttr =
1165 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1166 auto rhsAttr =
1167 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1168
1169 if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1170 return getInput1();
1171
1172 if (!lhsAttr || !rhsAttr)
1173 return {};
1174
1175 return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
1176 resultTy);
1177}
1178
1179namespace {
1180template <typename Cmp>
1181struct ComparisonFold {
1182 ComparisonFold() = default;
1183 APInt operator()(const APInt &l, const APInt &r) {
1184 return APInt(1, Cmp()(l, r));
1185 }
1186
1187 APInt operator()(const APFloat &l, const APFloat &r) {
1188 return APInt(1, Cmp()(l, r));
1189 }
1190};
1191
1192struct APIntFoldGreater {
1193 APIntFoldGreater() = default;
1194 APInt operator()(const APInt &l, const APInt &r) {
1195 return APInt(1, l.sgt(RHS: r));
1196 }
1197};
1198
1199struct APIntFoldGreaterEqual {
1200 APIntFoldGreaterEqual() = default;
1201 APInt operator()(const APInt &l, const APInt &r) {
1202 return APInt(1, l.sge(RHS: r));
1203 }
1204};
1205} // namespace
1206
1207OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1208 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1209 auto lhsAttr =
1210 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1211 auto rhsAttr =
1212 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1213
1214 if (!lhsAttr || !rhsAttr)
1215 return {};
1216
1217 return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
1218 lhsAttr, rhsAttr, resultTy);
1219}
1220
1221OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1222 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1223 auto lhsAttr =
1224 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1225 auto rhsAttr =
1226 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1227
1228 if (!lhsAttr || !rhsAttr)
1229 return {};
1230
1231 return binaryFolder<APIntFoldGreaterEqual,
1232 ComparisonFold<std::greater_equal<APFloat>>>(
1233 lhsAttr, rhsAttr, resultTy);
1234}
1235
1236OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1237 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1238 auto lhsAttr =
1239 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1240 auto rhsAttr =
1241 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1242 Value lhs = getInput1();
1243 Value rhs = getInput2();
1244 auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
1245
1246 // If we are comparing an integer value to itself it is always true. We can
1247 // not do this with float due to float values.
1248 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
1249 resultTy.hasStaticShape() && lhs == rhs) {
1250 return DenseElementsAttr::get(resultTy, true);
1251 }
1252
1253 if (!lhsAttr || !rhsAttr)
1254 return {};
1255
1256 return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
1257 ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
1258 resultTy);
1259}
1260
1261OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1262 if (getInput().getType() == getType())
1263 return getInput();
1264
1265 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1266 if (!operand)
1267 return {};
1268
1269 auto inTy = llvm::cast<ShapedType>(getInput().getType());
1270 auto outTy = llvm::cast<ShapedType>(getType());
1271 auto inETy = inTy.getElementType();
1272 auto outETy = outTy.getElementType();
1273
1274 if (operand.isSplat()) {
1275 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1276 bool overflow;
1277 auto splatVal = operand.getSplatValue<APFloat>();
1278 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1279 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1280 &overflow);
1281 return SplatElementsAttr::get(outTy, splatVal);
1282 }
1283
1284 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1285 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1286 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1287 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1288 llvm::RoundingMode::NearestTiesToEven);
1289 return SplatElementsAttr::get(outTy, splatVal);
1290 }
1291
1292 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1293 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1294 auto intVal = APSInt(
1295 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1296 auto floatVal = operand.getSplatValue<APFloat>();
1297 bool exact;
1298 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1299 &exact);
1300 return SplatElementsAttr::get(outTy, intVal);
1301 }
1302
1303 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1304 auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1305 bool trunc =
1306 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1307 auto intVal = operand.getSplatValue<APInt>();
1308 auto bitwidth = outETy.getIntOrFloatBitWidth();
1309
1310 if (trunc) {
1311 intVal = intVal.trunc(bitwidth);
1312 } else if (unsignIn) {
1313 intVal = intVal.zext(bitwidth);
1314 } else {
1315 intVal = intVal.sext(bitwidth);
1316 }
1317
1318 return SplatElementsAttr::get(outTy, intVal);
1319 }
1320 }
1321
1322 return {};
1323}
1324
1325OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1326
1327OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1328
1329#define REDUCE_FOLDER(OP) \
1330 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1331 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1332 if (!inputTy.hasRank()) \
1333 return {}; \
1334 if (inputTy != getType()) \
1335 return {}; \
1336 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1337 return getInput(); \
1338 return {}; \
1339 }
1340
1341REDUCE_FOLDER(ReduceAllOp)
1342REDUCE_FOLDER(ReduceAnyOp)
1343REDUCE_FOLDER(ReduceMaxOp)
1344REDUCE_FOLDER(ReduceMinOp)
1345REDUCE_FOLDER(ReduceProductOp)
1346REDUCE_FOLDER(ReduceSumOp)
1347#undef REDUCE_FOLDER
1348
1349OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1350 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1351 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1352
1353 if (!inputTy || !outputTy)
1354 return {};
1355
1356 // Fold when the input and output types are the same. This is only safe when
1357 // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
1358 // there may still be a productive reshape.
1359 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1360 return getInput1();
1361
1362 // reshape(reshape(x)) -> reshape(x)
1363 if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1364 getInput1().getDefiningOp())) {
1365 getInput1Mutable().assign(reshapeOp.getInput1());
1366 return getResult();
1367 }
1368
1369 // Cannot create an ElementsAttr from non-int/float/index types
1370 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1371 return {};
1372
1373 // reshape(const(x)) -> const(reshape-attr(x))
1374 if (auto operand =
1375 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1376 // Constants must have static shape.
1377 if (!outputTy.hasStaticShape())
1378 return {};
1379
1380 // Okay to duplicate splat constants.
1381 if (operand.isSplat())
1382 return SplatElementsAttr::get(outputTy,
1383 operand.getSplatValue<Attribute>());
1384
1385 // Don't duplicate other constants.
1386 if (!getInput1().hasOneUse())
1387 return {};
1388
1389 llvm::SmallVector<int64_t> shapeVec;
1390 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
1391 return {};
1392
1393 return operand.reshape(
1394 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1395 }
1396
1397 return {};
1398}
1399
1400OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1401 // If the pad is all zeros we can fold this operation away.
1402 if (adaptor.getPadding() && getInput1().getType() == getType()) {
1403 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1404 if (densePad && densePad.isSplat() &&
1405 densePad.getSplatValue<APInt>().isZero()) {
1406 return getInput1();
1407 }
1408 }
1409
1410 return {};
1411}
1412
1413// Fold away cases where a tosa.resize operation returns a copy
1414// of the input image.
1415OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1416 auto scaleAttr =
1417 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1418 auto offsetAttr =
1419 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1420 auto borderAttr =
1421 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1422 if (!scaleAttr || !offsetAttr || !borderAttr) {
1423 return {};
1424 }
1425
1426 auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
1427 auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
1428 auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
1429 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1430 return {};
1431 }
1432
1433 // Check unit scaling.
1434 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1435 return {};
1436 }
1437
1438 // There should be no offset.
1439 if (offset[0] != 0 || offset[1] != 0) {
1440 return {};
1441 }
1442
1443 // There should be no border.
1444 if (border[0] != 0 || border[1] != 0) {
1445 return {};
1446 }
1447
1448 auto input = getInput();
1449 auto inputTy = llvm::cast<RankedTensorType>(input.getType());
1450 auto resultTy = llvm::cast<RankedTensorType>(getType());
1451 if (inputTy != resultTy)
1452 return {};
1453
1454 return input;
1455}
1456
1457OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1458 auto operand = getInput1();
1459 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1460 auto axis = getAxis();
1461 auto operandAttr =
1462 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1463 if (operandAttr)
1464 return operandAttr;
1465
1466 // If the dim-length is 1, tosa.reverse is a no-op.
1467 if (operandTy.hasRank() &&
1468 (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1469 return operand;
1470
1471 return {};
1472}
1473
1474OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
1475 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1476 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1477
1478 if (!inputTy || !outputTy)
1479 return {};
1480
1481 if (inputTy == outputTy && inputTy.hasStaticShape())
1482 return getInput1();
1483
1484 if (!adaptor.getInput1())
1485 return {};
1486
1487 // Cannot create an ElementsAttr from non-int/float/index types
1488 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1489 !outputTy.getElementType().isIntOrIndexOrFloat())
1490 return {};
1491
1492 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1493 if (operand.isSplat() && outputTy.hasStaticShape()) {
1494 return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
1495 }
1496
1497 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1498 outputTy.getNumElements() == 1) {
1499 DenseElementsAttr startElems;
1500 if (!matchPattern(getStart(), m_Constant(&startElems)))
1501 return {};
1502
1503 llvm::SmallVector<uint64_t> indices =
1504 llvm::to_vector(startElems.getValues<uint64_t>());
1505 auto value = operand.getValues<Attribute>()[indices];
1506 return SplatElementsAttr::get(outputTy, value);
1507 }
1508
1509 return {};
1510}
1511
1512OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1513 if (getInput2() == getInput3())
1514 return getInput2();
1515
1516 auto predicate =
1517 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1518 if (!predicate)
1519 return {};
1520
1521 if (!predicate.isSplat())
1522 return {};
1523 return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
1524 : getInput3();
1525}
1526
1527OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1528 if (getInput1().getType() == getType()) {
1529 if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1530 adaptor.getMultiples())) {
1531 if (multiples.isSplat() &&
1532 multiples.getSplatValue<APInt>().getSExtValue() == 1)
1533 return getInput1();
1534 if (auto int_array_attr =
1535 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1536 if (llvm::all_of(int_array_attr.getValues<APInt>(),
1537 [](APInt v) { return v.getSExtValue() == 1; }))
1538 return getInput1();
1539 }
1540 }
1541 }
1542 return {};
1543}
1544
1545OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1546 auto resultTy = llvm::cast<ShapedType>(getType());
1547
1548 // Transposing splat values just means reshaping.
1549 if (auto input =
1550 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1551 if (input.isSplat() && resultTy.hasStaticShape() &&
1552 input.getType().getElementType() == resultTy.getElementType())
1553 return input.reshape(resultTy);
1554 }
1555
1556 // Transpose is not the identity transpose.
1557 const llvm::ArrayRef<int32_t> perms = getPerms();
1558
1559 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1560 return {};
1561
1562 return getInput1();
1563}
1564
1565OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
1566 auto input = getInput1();
1567 // Element-wise log(exp(x)) = x
1568 if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
1569 return op.getInput1();
1570 }
1571
1572 return {};
1573}
1574
1575OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
1576 auto input = getInput1();
1577 // Element-wise exp(log(x)) = x
1578 if (auto op = input.getDefiningOp<tosa::LogOp>()) {
1579 return op.getInput1();
1580 }
1581
1582 return {};
1583}
1584
1585OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1586 // Element-wise negate(negate(x)) = x
1587 // iff all zero points are constant 0
1588 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1589 if (!definingOp) {
1590 // defining op of input1 is not a negate, cannot fold
1591 return {};
1592 }
1593
1594 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1595 failed(maybeIZp) || *maybeIZp != 0) {
1596 // input1 zero point is not constant 0, cannot fold
1597 return {};
1598 }
1599 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1600 failed(maybeOZp) || *maybeOZp != 0) {
1601 // output zero point is not constant 0, cannot fold
1602 return {};
1603 }
1604 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1605 failed(maybeIZp) || *maybeIZp != 0) {
1606 // definingOp's input1 zero point is not constant 0, cannot fold
1607 return {};
1608 }
1609 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1610 failed(maybeOZp) || *maybeOZp != 0) {
1611 // definingOp's output zero point is not constant 0, cannot fold
1612 return {};
1613 }
1614
1615 return definingOp.getInput1();
1616}
1617
1618OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1619 auto input = getInput1();
1620 // Element-wise abs(abs(x)) = abs(x)
1621 if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
1622 return input;
1623 }
1624
1625 return {};
1626}
1627
1628OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1629 // Fold consecutive concats on the same axis into a single op.
1630 // Keep track of the operands so we are able to construct a new concat
1631 // later. Conservatively assume that we double the number of operands when
1632 // folding
1633 SmallVector<Value, 8> concatOperands;
1634 concatOperands.reserve(2 * getNumOperands());
1635
1636 // Find all operands that are foldable concats
1637 bool foundFoldableConcat = false;
1638 for (Value operand : getOperands()) {
1639 concatOperands.emplace_back(operand);
1640
1641 auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
1642 if (!producer)
1643 continue;
1644
1645 // Not foldable if axes are not the same
1646 if (getAxis() != producer.getAxis())
1647 continue;
1648
1649 // Replace the original operand with all incoming operands
1650 foundFoldableConcat = true;
1651 concatOperands.pop_back();
1652 llvm::append_range(concatOperands, producer->getOperands());
1653 }
1654
1655 if (!foundFoldableConcat)
1656 return {};
1657
1658 getOperation()->setOperands(concatOperands);
1659 return getResult();
1660}
1661
1662OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1663 auto input = adaptor.getInput1();
1664
1665 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1666 // Fold splat inputs only.
1667 if (!inputAttr || !inputAttr.isSplat())
1668 return {};
1669
1670 auto shapeType = llvm::cast<ShapedType>(getType());
1671 if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1672 auto floatVal = inputAttr.getSplatValue<APFloat>();
1673 return DenseElementsAttr::get(shapeType,
1674 ReciprocalOp::calcOneElement(floatVal));
1675 }
1676
1677 return {};
1678}
1679

// Source: mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp