//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// TOSA canonicalization patterns and folders.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <functional>

using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//
struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ConcatOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInput1().size() != 1)
      return failure();
    if (op.getInput1().front().getType() != op.getType()) {
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                  op.getInput1().front());
      return success();
    }

    rewriter.replaceOp(op, op.getInput1().front());
    return success();
  }
};

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<ConcatOptimization>(context);
}

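// tosa.select canonicalization: when the predicate is produced by
// tosa.logical_not, drop the negation and swap the branches. Sketch:
//
//   select(logical_not(%p), %a, %b)  ==>  select(%p, %b, %a)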
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  rewriter.modifyOpInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
  });
  return success();
}

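// Compose two back-to-back transposes with constant permutations into one
// transpose. Sketch (permutations are examples only):
//
//   transpose(transpose(%x, perms = [1, 2, 0]), perms = [2, 0, 1])
//     ==> transpose(%x, perms = [composed])
//
// where composed[i] = inner_perms[outer_perms[i]].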
struct ConsolidateTransposeOptimization
    : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Input is also TransposeOp - transpose(transpose(A)).
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    SmallVector<int64_t> transposePerms, innerTransposePerms;
    if (transposeOp.getConstantPerms(transposePerms).failed())
      return rewriter.notifyMatchFailure(transposeOp,
                                         "transpose perms must be constant");
    if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
      return rewriter.notifyMatchFailure(
          transposeOp, "inner transpose perms must be constant");
    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    // Consolidate transposes into one transpose.
    SmallVector<int32_t> perms(transposePerms.size());
    for (int i = 0, s = transposePerms.size(); i < s; ++i)
      perms[i] = innerTransposePerms[transposePerms[i]];

    auto permsTy =
        RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
    auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
    Value permsValue =
        rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);

    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        innerTranspose.getInput1(), permsValue);

    return success();
  }
};

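// Illustrative sketch for the pattern below (shape and permutation are
// examples only): a transpose that only moves unit dimensions,
//
//   transpose(%x : tensor<1x4x1x8xf32>, perms = [1, 0, 2, 3])
//
// can be rewritten as reshape(%x) producing tensor<4x1x1x8xf32>.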
// Determine when a tosa.transpose only moves unit dimensions (so the data
// layout is unchanged) and rewrite it as a tosa.reshape.
struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return rewriter.notifyMatchFailure(op, "Non-constant permutation");

    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    Value result = op.getResult();
    for (Operation *subop : result.getUsers()) {
      if (dyn_cast_or_null<tosa::TransposeOp>(subop))
        return rewriter.notifyMatchFailure(
            op, "Dest is used by transpose, can compose transposes");
    }

    auto input = op.getInput1();
    auto inputTy = llvm::cast<ShapedType>(input.getType());
    if (!inputTy.hasRank())
      return rewriter.notifyMatchFailure(op, "Unranked input.");

    int64_t numDynDims = 0;
    for (int i = 0; i < inputTy.getRank(); ++i)
      if (inputTy.isDynamicDim(i))
        numDynDims++;

    if (numDynDims > 1)
      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

    SmallVector<int64_t> permValues = llvm::to_vector<6>(
        llvm::map_range(permAttr.getValues<APInt>(),
                        [](const APInt &val) { return val.getSExtValue(); }));

    SmallVector<int64_t> nonZeroPerms;
    nonZeroPerms.reserve(permValues.size());
    for (auto idx : permValues) {
      auto sz = inputTy.getDimSize(idx);
      if (sz != 1)
        nonZeroPerms.push_back(idx);
    }

    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
        return rewriter.notifyMatchFailure(op,
                                           "Transpose changes memory layout.");

    SmallVector<int64_t> newShape;
    newShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
      newShape.push_back(inputTy.getDimSize(permValues[i]));

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, op.getType(), op.getInput1(),
        rewriter.getDenseI64ArrayAttr(newShape));
    return success();
  }
};

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}

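// When a tosa.pad has no explicit pad_const operand, materialize one: zero
// for float and signless integer inputs, or the input zero-point when
// quantization info is present. Sketch:
//
//   tosa.pad(%x, %padding)  ==>  tosa.pad(%x, %padding, %zero_const)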
struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getPadConst())
      return failure();

    auto input = op.getInput1();
    auto padding = op.getPadding();

    ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();

    Attribute constantAttr;
    if (llvm::isa<FloatType>(elementTy)) {
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
      auto value = op.getQuantizationInfo()->getInputZp();
      constantAttr = rewriter.getIntegerAttr(elementTy, value);
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          op,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    auto denseAttr = DenseElementsAttr::get(
        RankedTensorType::get({}, elementTy), constantAttr);
    auto constantVal = rewriter.create<tosa::ConstOp>(
        op.getLoc(), denseAttr.getType(), denseAttr);

    rewriter.replaceOpWithNewOp<tosa::PadOp>(
        op, op.getType(), ValueRange{input, padding, constantVal},
        op->getAttrs());
    return success();
  }
};

void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<MaterializePadValue>(context);
}

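// A tosa.max_pool2d whose input and output both have 1x1 spatial extent
// cannot change any value, so it folds to its input. Sketch (shape is an
// example only):
//
//   max_pool2d(%x : tensor<2x1x1x8xf32>) -> tensor<2x1x1x8xf32>  ==>  %x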
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value output = op.getOutput();
    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
    ShapedType outputType = llvm::cast<ShapedType>(output.getType());

    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
      return failure();
    }

    // If both the input and output spatial dimensions (H, W) are 1x1, the
    // pooling is a no-op.
    ArrayRef<int64_t> outputShape = outputType.getShape();
    if (outputShape[1] != 1 || outputShape[2] != 1) {
      return failure();
    }

    ArrayRef<int64_t> inputShape = inputType.getShape();
    if (inputShape[1] != 1 || inputShape[2] != 1) {
      return failure();
    }

    rewriter.replaceOp(op, input);
    return success();
  }
};

void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<MaxPool2dIsNoOp>(context);
}

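// A tosa.clamp whose bounds cover the full range of the element type cannot
// modify any value and folds to its input: [-inf, +inf] for floats, the full
// unsigned or signed range for integers.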
struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    if (!inputType || !inputType.hasStaticShape()) {
      return failure();
    }
    auto inputElementType = inputType.getElementType();

    if (isa<FloatType>(inputElementType)) {
      // Unlike integer types, floating point types can represent infinity.
      auto minClamp = op.getMinFp();
      auto maxClamp = op.getMaxFp();
      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();

      if (isMin && isMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (inputElementType.isUnsignedInteger()) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      int64_t intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (llvm::isa<IntegerType>(inputElementType)) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    return failure();
  }
};

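// Collapse clamp(clamp(x)) into a single clamp whose bounds are the
// intersection of the two ranges. Sketch (bounds are examples only):
//
//   clamp(clamp(%x, min = 0, max = 10), min = 2, max = 20)
//     ==> clamp(%x, min = 2, max = 10)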
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();

    Operation *definingOp = input.getDefiningOp();
    if (!definingOp)
      return failure();

    if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
      auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
      auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();

      auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
      auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());

      rewriter.replaceOpWithNewOp<tosa::ClampOp>(
          op, op.getType(), clampOp.getInput(),
          rewriter.getI64IntegerAttr(minInt),
          rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
          rewriter.getF32FloatAttr(maxFp));
      return success();
    }

    return failure();
  }
};

void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ClampIsNoOp>(context);
  results.add<ClampClampOptimization>(context);
}

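// When a tosa.slice reads entirely from one operand of a tosa.concat, slice
// that operand directly and bypass the concat. Sketch (shapes, start, and
// size are examples only):
//
//   %c = concat(%a : tensor<4x8xf32>, %b : tensor<4x8xf32>, axis = 0)
//   slice(%c, start = [4, 0], size = [2, 8])
//     ==> slice(%b, start = [0, 0], size = [2, 8])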
struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    OperandRange inputs = concatOp.getInput1();
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
    llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();

    // Validate slice on the concatenated axis. Slicing along this
    // axis should span only one of the inputs to the concatenate
    // operation.
    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      if (sliceStart[axis] >= 0 &&
          (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
        replaceWithSlice = rewriter
                               .create<tosa::SliceOp>(
                                   sliceOp.getLoc(), sliceOp.getType(), input,
                                   rewriter.getDenseI64ArrayAttr(sliceStart),
                                   rewriter.getDenseI64ArrayAttr(sliceSize))
                               .getResult();
        break;
      }
      sliceStart[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
  }
};

void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ConcatSliceOptimization>(context);
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

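// Shared helper for the elementwise binary folders below: when both operands
// are splat constants of the same element type, apply IntFolder to the
// integer splats or FloatFolder to the float splats and rebuild a splat
// result. For example, AddOp::fold instantiates it as
//   binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr, ty);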
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
    if (lETy != rETy)
      return {};

    if (llvm::isa<IntegerType>(lETy)) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();
      auto result = IntFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }

    if (llvm::isa<FloatType>(lETy)) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      auto result = FloatFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }
  }

  return {};
}

static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}

static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
           val.getSplatValue<APInt>().getSExtValue() == shifted;
  }
  return false;
}

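// tosa.add folding: x + 0 folds to x when the operand type already matches
// the result type, and two splat constants fold to a new splat constant.
// Sketch:
//
//   add(%x, splat<0>)        ==>  %x
//   add(splat<2>, splat<3>)  ==>  splat<5>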
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}

OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
      !outputTy.hasStaticShape())
    return {};

  if (inputTy.getDimSize(getAxis()) == 1)
    return DenseElementsAttr::get(outputTy, 0);

  return {};
}

OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy != rhsTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy)) {
      APInt l = lhsAttr.getSplatValue<APInt>();
      APInt r = rhsAttr.getSplatValue<APInt>();
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }

  return {};
}

namespace {
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      if (shift == 0) {
        return DenseElementsAttr::get(ty, l * r);
      }

      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      l = l.sext(bitwidth * 2);
      r = r.sext(bitwidth * 2);
      auto result = l * r;
      result.lshrInPlace(shift);
      result = result.trunc(bitwidth);
      return DenseElementsAttr::get(ty, result);
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
} // namespace

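// tosa.mul folding: multiplying by a splat zero yields a zero splat of the
// result type, multiplying by a splat one (or 1 << shift for the shifted
// integer form) yields the other operand, and two splat constants fold via
// mulBinaryFolder above. Sketch:
//
//   mul(%x, splat<1>)  ==>  %x
//   mul(%x, splat<0>)  ==>  splat<0>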
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
}

OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

namespace {
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};

struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
} // namespace

OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // Comparing an integer value to itself is always true. We cannot fold this
  // for floats, since a NaN operand makes the comparison false.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

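// tosa.cast folding on splat constants: float-to-float rounds to the target
// semantics, int-to-float and float-to-int convert through APFloat/APSInt,
// and int-to-int truncates or extends (zero- or sign-extending based on the
// source signedness). A cast whose source and result types already match
// folds to its input directly.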
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::TowardZero, &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}

OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

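// Each reduce op folds to its input when the input and result types match and
// the reduced axis already has size 1 (or the input is rank 0): reducing a
// unit dimension is a no-op. Sketch:
//
//   reduce_sum(%x : tensor<4x1xf32>, axis = 1)  ==>  %x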
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }

REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProdOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER

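// tosa.reshape folding: an identity reshape folds to its input,
// reshape(reshape(x)) collapses into a single reshape of x, and a reshape of
// a constant folds to a reshaped constant (always for splats; for other
// constants only when the reshape is the sole use, to avoid duplicating
// data).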
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // Fold when the input and output types are the same. This is only safe when
  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
  // there may still be a productive reshape.
  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x)
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  // reshape(const(x)) -> const(reshape-attr(x))
  if (auto operand =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    // Constants must have static shape.
    if (!outputTy.hasStaticShape())
      return {};

    // Okay to duplicate splat constants.
    if (operand.isSplat())
      return SplatElementsAttr::get(outputTy,
                                    operand.getSplatValue<Attribute>());

    // Don't duplicate other constants.
    if (!getInput1().hasOneUse())
      return {};

    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
  }

  return {};
}

OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  // If the pad is all zeros we can fold this operation away.
  if (adaptor.getPadding()) {
    auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
    if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
      return getInput1();
    }
  }

  return {};
}

// Fold away cases where a tosa.resize operation returns a copy
// of the input image.
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> offset = getOffset();
  ArrayRef<int64_t> border = getBorder();
  ArrayRef<int64_t> scale = getScale();

  // Check unit scaling.
  if (scale[0] != scale[1] || scale[2] != scale[3]) {
    return {};
  }

  // There should be no offset.
  if (offset[0] != 0 || offset[1] != 0) {
    return {};
  }

  // There should be no border.
  if (border[0] != 0 || border[1] != 0) {
    return {};
  }

  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};

  return input;
}

OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr =
      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
  if (operandAttr)
    return operandAttr;

  // If the dim-length is 1, tosa.reverse is a no-op.
  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;

  return {};
}

OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput();

  if (!adaptor.getInput())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    llvm::SmallVector<uint64_t> indices(getStart());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}

OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getOnTrue() == getOnFalse())
    return getOnTrue();

  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
  if (!predicate)
    return {};

  if (!predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
                                                         : getOnFalse();
}

OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
  if (allOnes && getInput1().getType() == getType())
    return getInput1();
  return {};
}

OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::cast<ShapedType>(getInput1().getType());
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing splat values just means reshaping.
  if (auto input =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        inputTy.getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }

  // A transpose can only fold to its input if it does not change the type.
  if (getInput1().getType() != getType())
    return {};

  // Fold only when the permutation is constant and is the identity.
  SmallVector<int64_t> perms;
  if (getConstantPerms(perms).failed())
    return {};

  if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
    return {};

  return getInput1();
}

OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise log(exp(x)) = x
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise exp(log(x)) = x
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise negate(negate(x)) = x
  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise abs(abs(x)) = abs(x)
  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
    return input;
  }

  return {};
}

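// Flatten nested concats: any operand that is itself a tosa.concat on the
// same axis is replaced in place by that concat's operands. Sketch:
//
//   concat(concat(%a, %b, axis = 0), %c, axis = 0)
//     ==> concat(%a, %b, %c, axis = 0)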
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Fold consecutive concats on the same axis into a single op.
  // Keep track of the operands so we are able to construct a new concat
  // later. Conservatively assume that we double the number of operands when
  // folding.
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  // Find all operands that are foldable concats.
  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;

    // Not foldable if the axes are not the same.
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the original operand with all incoming operands.
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}

OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto input = adaptor.getInput1();

  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
  // Fold splat inputs only.
  if (!inputAttr || !inputAttr.isSplat())
    return {};

  auto shapeType = llvm::cast<ShapedType>(getType());
  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
    auto floatVal = inputAttr.getSplatValue<APFloat>();
    return DenseElementsAttr::get(shapeType,
                                  ReciprocalOp::calcOneElement(floatVal));
  }

  return {};
}