1 | //===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // \file |
10 | // TOSA canonicalization patterns and folders. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Quant/QuantOps.h" |
15 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
16 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
17 | #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" |
18 | #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" |
19 | #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" |
20 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
21 | #include "mlir/IR/BuiltinTypes.h" |
22 | #include "mlir/IR/DialectImplementation.h" |
23 | #include "mlir/IR/Matchers.h" |
24 | #include "mlir/IR/PatternMatch.h" |
25 | #include "mlir/Transforms/FoldUtils.h" |
26 | #include "mlir/Transforms/InliningUtils.h" |
27 | #include "mlir/Transforms/RegionUtils.h" |
28 | #include "llvm/ADT/APFloat.h" |
29 | #include "llvm/ADT/APInt.h" |
30 | #include "llvm/ADT/DenseMap.h" |
31 | #include "llvm/ADT/TypeSwitch.h" |
32 | |
33 | #include <functional> |
34 | |
35 | using namespace mlir; |
36 | using namespace mlir::tosa; |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // Operator Canonicalizers. |
40 | //===----------------------------------------------------------------------===// |
41 | |
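// Replaces a tosa.concat that has exactly one input with that input,
// inserting a tensor.cast when the result type differs. As an illustrative
// sketch (shape assumed for the example), concatenating the single tensor
// %x : tensor<4xf32> along any axis simply yields %x.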
42 | struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> { |
43 | using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern; |
44 | |
45 | LogicalResult matchAndRewrite(tosa::ConcatOp op, |
46 | PatternRewriter &rewriter) const override { |
47 | if (op.getInput1().size() != 1) |
48 | return failure(); |
49 | if (op.getInput1().front().getType() != op.getType()) { |
50 | rewriter |
51 | .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), |
52 | op.getInput1().front()) |
53 | .getResult(); |
54 | return success(); |
55 | } |
56 | |
57 | rewriter.replaceOp(op, op.getInput1().front()); |
58 | return success(); |
59 | } |
60 | }; |
61 | |
62 | void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, |
63 | MLIRContext *context) { |
64 | results.add<ConcatOptimization>(context); |
65 | } |
66 | |
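// Canonicalizes select(logical_not(pred), on_true, on_false) into
// select(pred, on_false, on_true), dropping the explicit negation by
// swapping the two value operands in place.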
67 | LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) { |
68 | auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>(); |
69 | if (!notOp) |
70 | return failure(); |
71 | rewriter.modifyOpInPlace(op, [&]() { |
72 | op.getOperation()->setOperands( |
73 | {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()}); |
74 | }); |
75 | return success(); |
76 | } |
77 | |
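// Folds transpose(transpose(x, innerPerms), outerPerms) into a single
// transpose with the composed permutation perms[i] = innerPerms[outerPerms[i]].
// For example (values assumed), inner perms [1, 0, 2] composed with outer
// perms [2, 0, 1] give [2, 1, 0].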
78 | struct ConsolidateTransposeOptimization |
79 | : public OpRewritePattern<tosa::TransposeOp> { |
80 | using OpRewritePattern::OpRewritePattern; |
81 | |
82 | LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, |
83 | PatternRewriter &rewriter) const override { |
84 | // Input is also TransposeOp - transpose(transpose(A)). |
85 | auto innerTranspose = |
86 | transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>(); |
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    SmallVector<int64_t> transposePerms, innerTransposePerms;
    if (transposeOp.getConstantPerms(transposePerms).failed())
      return rewriter.notifyMatchFailure(transposeOp,
                                         "transpose perms must be constant");
    if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
      return rewriter.notifyMatchFailure(
          transposeOp, "inner transpose perms must be constant");
    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");
105 | |
106 | // Consolidate transposes into one transpose. |
107 | SmallVector<int32_t> perms(transposePerms.size()); |
108 | for (int i = 0, s = transposePerms.size(); i < s; ++i) |
109 | perms[i] = innerTransposePerms[transposePerms[i]]; |
110 | |
111 | auto permsTy = |
112 | RankedTensorType::get(transposePerms.size(), rewriter.getI32Type()); |
113 | auto permsAttr = DenseIntElementsAttr::get(permsTy, perms); |
114 | Value permsValue = |
115 | rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr); |
116 | |
117 | rewriter.replaceOpWithNewOp<tosa::TransposeOp>( |
118 | transposeOp, transposeOp.getResult().getType(), |
119 | innerTranspose.getInput1(), permsValue); |
120 | |
121 | return success(); |
122 | } |
123 | }; |
124 | |
// Replaces a tosa.transpose with a tosa.reshape when the permutation only
// moves unit dimensions, i.e. the transpose does not change the underlying
// memory layout.
126 | struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> { |
127 | using OpRewritePattern::OpRewritePattern; |
128 | |
129 | LogicalResult matchAndRewrite(tosa::TransposeOp op, |
130 | PatternRewriter &rewriter) const override { |
    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return rewriter.notifyMatchFailure(op, "Non-constant permutation");

    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    Value result = op.getResult();
    for (Operation *subop : result.getUsers()) {
      if (dyn_cast_or_null<tosa::TransposeOp>(subop))
        return rewriter.notifyMatchFailure(
            op, "Dest is used by transpose, can compose transposes");
144 | } |
145 | |
146 | auto input = op.getInput1(); |
147 | auto inputTy = llvm::cast<ShapedType>(input.getType()); |
148 | if (!inputTy.hasRank()) |
      return rewriter.notifyMatchFailure(op, "Unranked input.");
150 | |
151 | int64_t numDynDims = 0; |
152 | for (int i = 0; i < inputTy.getRank(); ++i) |
153 | if (inputTy.isDynamicDim(i)) |
154 | numDynDims++; |
155 | |
156 | if (numDynDims > 1) |
      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
158 | |
159 | SmallVector<int64_t> permValues = llvm::to_vector<6>( |
160 | llvm::map_range(permAttr.getValues<APInt>(), |
161 | [](const APInt &val) { return val.getSExtValue(); })); |
162 | |
    SmallVector<int64_t> nonZeroPerms;
    nonZeroPerms.reserve(permValues.size());
    for (auto idx : permValues) {
      auto sz = inputTy.getDimSize(idx);
      if (sz != 1)
        nonZeroPerms.push_back(idx);
    }

    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
        return rewriter.notifyMatchFailure(op,
                                           "Transpose changes memory layout.");

    SmallVector<int64_t> newShape;
    newShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
      newShape.push_back(inputTy.getDimSize(permValues[i]));
180 | |
181 | rewriter.replaceOpWithNewOp<tosa::ReshapeOp>( |
182 | op, op.getType(), op.getInput1(), |
183 | rewriter.getDenseI64ArrayAttr(newShape)); |
184 | return success(); |
185 | } |
186 | }; |
187 | |
188 | void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, |
189 | MLIRContext *context) { |
190 | results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context); |
191 | } |
192 | |
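// Materializes an explicit pad_const operand for tosa.pad when it is absent:
// 0.0 for floats, 0 for integers, or the input zero point when quantization
// info is present, so later lowerings see a uniform three-operand form.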
193 | struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> { |
194 | using OpRewritePattern::OpRewritePattern; |
195 | |
196 | LogicalResult matchAndRewrite(tosa::PadOp op, |
197 | PatternRewriter &rewriter) const override { |
198 | if (op.getPadConst()) |
199 | return failure(); |
200 | |
201 | auto input = op.getInput1(); |
202 | auto padding = op.getPadding(); |
203 | |
204 | ShapedType inputTy = llvm::cast<ShapedType>(input.getType()); |
205 | Type elementTy = inputTy.getElementType(); |
206 | |
207 | Attribute constantAttr; |
    if (llvm::isa<FloatType>(elementTy)) {
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
213 | auto value = op.getQuantizationInfo()->getInputZp(); |
214 | constantAttr = rewriter.getIntegerAttr(elementTy, value); |
215 | } |
216 | |
217 | if (!constantAttr) { |
218 | return rewriter.notifyMatchFailure( |
219 | op, |
220 | "tosa.pad to linalg lowering encountered an unknown element type" ); |
221 | } |
222 | |
223 | auto denseAttr = DenseElementsAttr::get( |
224 | RankedTensorType::get({}, elementTy), constantAttr); |
225 | auto constantVal = rewriter.create<tosa::ConstOp>( |
226 | op.getLoc(), denseAttr.getType(), denseAttr); |
227 | |
228 | rewriter.replaceOpWithNewOp<tosa::PadOp>( |
229 | op, op.getType(), ValueRange{input, padding, constantVal}, |
230 | op->getAttrs()); |
231 | return success(); |
232 | } |
233 | }; |
234 | |
235 | void PadOp::getCanonicalizationPatterns(RewritePatternSet &results, |
236 | MLIRContext *context) { |
237 | results.add<MaterializePadValue>(context); |
238 | } |
239 | |
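// Removes a tosa.max_pool2d whose input and output both have static shapes
// with 1x1 spatial dimensions; pooling over a single element is an identity.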
240 | struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> { |
241 | using OpRewritePattern::OpRewritePattern; |
242 | |
243 | LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, |
244 | PatternRewriter &rewriter) const override { |
245 | Value input = op.getInput(); |
246 | Value output = op.getOutput(); |
247 | ShapedType inputType = llvm::cast<ShapedType>(input.getType()); |
248 | ShapedType outputType = llvm::cast<ShapedType>(output.getType()); |
249 | |
250 | if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) { |
251 | return failure(); |
252 | } |
253 | |
    // If the input and output spatial dimensions (H and W) are both 1, the
    // pooling is a no-op.
255 | ArrayRef<int64_t> outputShape = outputType.getShape(); |
256 | if (outputShape[1] != 1 || outputShape[2] != 1) { |
257 | return failure(); |
258 | } |
259 | |
260 | ArrayRef<int64_t> inputShape = inputType.getShape(); |
261 | if (inputShape[1] != 1 || inputShape[2] != 1) { |
262 | return failure(); |
263 | } |
264 | |
265 | rewriter.replaceOp(op, input); |
266 | return success(); |
267 | } |
268 | }; |
269 | |
270 | void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results, |
271 | MLIRContext *context) { |
272 | results.add<MaxPool2dIsNoOp>(context); |
273 | } |
274 | |
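// Removes a tosa.clamp whose bounds already cover the entire representable
// range of the element type: [-inf, +inf] for floats, or the full unsigned or
// signed range for integer types.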
275 | struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> { |
276 | using OpRewritePattern::OpRewritePattern; |
277 | |
278 | LogicalResult matchAndRewrite(tosa::ClampOp op, |
279 | PatternRewriter &rewriter) const override { |
280 | Value input = op.getInput(); |
281 | auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType()); |
282 | auto inputElementType = inputType.getElementType(); |
283 | |
284 | if (!inputType.hasStaticShape()) { |
285 | return failure(); |
286 | } |
287 | |
288 | if (isa<FloatType>(inputElementType)) { |
289 | // Unlike integer types, floating point types can represent infinity. |
290 | auto minClamp = op.getMinFp(); |
291 | auto maxClamp = op.getMaxFp(); |
292 | bool isMin = minClamp.isInfinity() && minClamp.isNegative(); |
293 | bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative(); |
294 | |
295 | if (isMin && isMax) { |
296 | rewriter.replaceOp(op, input); |
297 | return success(); |
298 | } |
299 | return failure(); |
300 | } |
301 | |
302 | if (inputElementType.isUnsignedInteger()) { |
303 | int64_t minClamp = op.getMinInt(); |
304 | int64_t maxClamp = op.getMaxInt(); |
305 | |
      int64_t intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      int64_t intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
312 | |
313 | if (minClamp <= intMin && maxClamp >= intMax) { |
314 | rewriter.replaceOp(op, input); |
315 | return success(); |
316 | } |
317 | return failure(); |
318 | } |
319 | |
320 | if (llvm::isa<IntegerType>(inputElementType)) { |
321 | int64_t minClamp = op.getMinInt(); |
322 | int64_t maxClamp = op.getMaxInt(); |
323 | |
      int64_t intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
330 | |
331 | if (minClamp <= intMin && maxClamp >= intMax) { |
332 | rewriter.replaceOp(op, input); |
333 | return success(); |
334 | } |
335 | return failure(); |
336 | } |
337 | |
338 | return failure(); |
339 | } |
340 | }; |
341 | |
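// Folds clamp(clamp(x, a, b), c, d) into clamp(x, max(a, c), min(b, d)),
// intersecting both the integer and the floating-point bounds.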
342 | struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> { |
343 | using OpRewritePattern<tosa::ClampOp>::OpRewritePattern; |
344 | |
345 | LogicalResult matchAndRewrite(tosa::ClampOp op, |
346 | PatternRewriter &rewriter) const override { |
347 | Value input = op.getInput(); |
348 | |
349 | Operation *definingOp = input.getDefiningOp(); |
350 | if (!definingOp) |
351 | return failure(); |
352 | |
353 | if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) { |
354 | auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat(); |
355 | auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat(); |
356 | |
357 | auto minInt = std::max(op.getMinInt(), clampOp.getMinInt()); |
358 | auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt()); |
359 | |
360 | rewriter.replaceOpWithNewOp<tosa::ClampOp>( |
361 | op, op.getType(), clampOp.getInput(), |
362 | rewriter.getI64IntegerAttr(minInt), |
363 | rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp), |
364 | rewriter.getF32FloatAttr(maxFp)); |
365 | return success(); |
366 | } |
367 | |
368 | return failure(); |
369 | } |
370 | }; |
371 | |
372 | void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results, |
373 | MLIRContext *context) { |
374 | results.add<ClampIsNoOp>(context); |
375 | results.add<ClampClampOptimization>(context); |
376 | } |
377 | |
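// Rewrites slice(concat(...)) so that it slices a single concat input
// directly, provided the sliced region lies entirely within that input along
// the concatenation axis.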
378 | struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> { |
379 | using OpRewritePattern<tosa::SliceOp>::OpRewritePattern; |
380 | |
381 | LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, |
382 | PatternRewriter &rewriter) const override { |
383 | Value sliceInput = sliceOp.getInput(); |
384 | auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>(); |
385 | if (!concatOp) |
386 | return rewriter.notifyMatchFailure( |
387 | sliceOp, "slice input must be concat operation" ); |
388 | |
389 | OperandRange inputs = concatOp.getInput1(); |
390 | auto concatType = dyn_cast<RankedTensorType>(concatOp.getType()); |
391 | if (!concatType || !concatType.hasStaticShape()) |
392 | return rewriter.notifyMatchFailure( |
393 | sliceOp, "slice input must be a static ranked tensor" ); |
394 | int32_t axis = concatOp.getAxis(); |
395 | |
396 | llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart()); |
397 | llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize(); |
398 | |
399 | // Validate slice on the concatenated axis. Slicing along this |
400 | // axis should span only one of the inputs to the concatenate |
401 | // operation. |
402 | std::optional<Value> replaceWithSlice; |
403 | for (auto input : inputs) { |
404 | auto inputType = dyn_cast<RankedTensorType>(input.getType()); |
405 | if (!inputType || !inputType.hasStaticShape()) |
406 | return rewriter.notifyMatchFailure( |
407 | sliceOp, "concat input must be a static ranked tensor" ); |
408 | |
409 | if (sliceStart[axis] >= 0 && |
410 | (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) { |
411 | replaceWithSlice = rewriter |
412 | .create<tosa::SliceOp>( |
413 | sliceOp.getLoc(), sliceOp.getType(), input, |
414 | rewriter.getDenseI64ArrayAttr(sliceStart), |
415 | rewriter.getDenseI64ArrayAttr(sliceSize)) |
416 | .getResult(); |
417 | break; |
418 | } |
419 | sliceStart[axis] -= inputType.getDimSize(axis); |
420 | } |
421 | |
422 | if (!replaceWithSlice) |
423 | return rewriter.notifyMatchFailure( |
424 | sliceOp, "corresponding concat input not found for slice" ); |
425 | |
426 | rewriter.replaceOp(sliceOp, replaceWithSlice.value()); |
427 | return success(); |
428 | } |
429 | }; |
430 | |
431 | void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, |
432 | MLIRContext *context) { |
433 | results.add<ConcatSliceOptimization>(context); |
434 | } |
435 | |
436 | //===----------------------------------------------------------------------===// |
437 | // Operator Folders. |
438 | //===----------------------------------------------------------------------===// |
439 | |
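// Folds an elementwise binary op when both operands are splat constants with
// matching element types, applying IntFolder to integer splats and FloatFolder
// to floating-point splats; AddOp::fold below instantiates it with std::plus,
// for instance.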
440 | template <typename IntFolder, typename FloatFolder> |
441 | DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, |
442 | RankedTensorType returnTy) { |
443 | if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) { |
444 | auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType(); |
445 | auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType(); |
446 | if (lETy != rETy) |
447 | return {}; |
448 | |
449 | if (llvm::isa<IntegerType>(lETy)) { |
450 | APInt l = lhs.getSplatValue<APInt>(); |
451 | APInt r = rhs.getSplatValue<APInt>(); |
452 | auto result = IntFolder()(l, r); |
453 | return DenseElementsAttr::get(returnTy, result); |
454 | } |
455 | |
456 | if (llvm::isa<FloatType>(lETy)) { |
457 | APFloat l = lhs.getSplatValue<APFloat>(); |
458 | APFloat r = rhs.getSplatValue<APFloat>(); |
459 | auto result = FloatFolder()(l, r); |
460 | return DenseElementsAttr::get(returnTy, result); |
461 | } |
462 | } |
463 | |
464 | return {}; |
465 | } |
466 | |
467 | static bool isSplatZero(Type elemType, DenseElementsAttr val) { |
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
472 | return false; |
473 | } |
474 | |
475 | static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) { |
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
480 | const int64_t shifted = 1LL << shift; |
481 | return val && val.isSplat() && |
482 | val.getSplatValue<APInt>().getSExtValue() == shifted; |
483 | } |
484 | return false; |
485 | } |
486 | |
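// add(x, splat(0)) and add(splat(0), x) fold to x when the operand type equals
// the result type; two splat constants fold to their elementwise sum.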
487 | OpFoldResult AddOp::fold(FoldAdaptor adaptor) { |
488 | auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType()); |
489 | auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType()); |
490 | auto resultTy = llvm::dyn_cast<RankedTensorType>(getType()); |
491 | if (!lhsTy || !rhsTy || !resultTy) |
492 | return {}; |
493 | |
494 | auto resultETy = resultTy.getElementType(); |
495 | auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1()); |
496 | auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2()); |
497 | |
498 | if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr)) |
499 | return getInput1(); |
500 | if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr)) |
501 | return getInput2(); |
502 | |
503 | if (!lhsAttr || !rhsAttr) |
504 | return {}; |
505 | |
506 | return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr, |
507 | resultTy); |
508 | } |
509 | |
510 | OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) { |
511 | auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType()); |
512 | auto outputTy = llvm::dyn_cast<RankedTensorType>(getType()); |
513 | if (!inputTy || !outputTy || !inputTy.hasStaticShape() || |
514 | !outputTy.hasStaticShape()) |
515 | return {}; |
516 | |
517 | if (inputTy.getDimSize(getAxis()) == 1) |
518 | return DenseElementsAttr::get(outputTy, 0); |
519 | |
520 | return {}; |
521 | } |
522 | |
523 | OpFoldResult DivOp::fold(FoldAdaptor adaptor) { |
524 | auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType()); |
525 | auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType()); |
526 | auto resultTy = llvm::dyn_cast<RankedTensorType>(getType()); |
527 | if (!lhsTy || !rhsTy || !resultTy) |
528 | return {}; |
529 | if (lhsTy != rhsTy) |
530 | return {}; |
531 | |
532 | auto resultETy = resultTy.getElementType(); |
533 | auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1()); |
534 | auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2()); |
535 | if (lhsAttr && lhsAttr.isSplat()) { |
536 | if (llvm::isa<IntegerType>(resultETy) && |
537 | lhsAttr.getSplatValue<APInt>().isZero()) |
538 | return lhsAttr; |
539 | } |
540 | |
541 | if (rhsAttr && rhsAttr.isSplat()) { |
542 | if (llvm::isa<IntegerType>(resultETy) && |
543 | rhsAttr.getSplatValue<APInt>().isOne()) |
544 | return getInput1(); |
545 | } |
546 | |
547 | if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) { |
548 | if (llvm::isa<IntegerType>(resultETy)) { |
549 | APInt l = lhsAttr.getSplatValue<APInt>(); |
550 | APInt r = rhsAttr.getSplatValue<APInt>(); |
551 | APInt result = l.sdiv(r); |
552 | return DenseElementsAttr::get(resultTy, result); |
553 | } |
554 | } |
555 | |
556 | return {}; |
557 | } |
558 | |
559 | namespace { |
560 | DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, |
561 | RankedTensorType ty, int32_t shift) { |
562 | if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) { |
563 | if (llvm::isa<IntegerType>(ty.getElementType())) { |
564 | APInt l = lhs.getSplatValue<APInt>(); |
565 | APInt r = rhs.getSplatValue<APInt>(); |
566 | |
567 | if (shift == 0) { |
568 | return DenseElementsAttr::get(ty, l * r); |
569 | } |
570 | |
571 | auto bitwidth = ty.getElementType().getIntOrFloatBitWidth(); |
      l = l.sext(bitwidth * 2);
      r = r.sext(bitwidth * 2);
      auto result = l * r;
      result.lshrInPlace(shift);
      result = result.trunc(bitwidth);
577 | return DenseElementsAttr::get(ty, result); |
578 | } |
579 | |
580 | if (llvm::isa<FloatType>(ty.getElementType())) { |
581 | APFloat l = lhs.getSplatValue<APFloat>(); |
582 | APFloat r = rhs.getSplatValue<APFloat>(); |
583 | APFloat result = l * r; |
584 | return DenseElementsAttr::get(ty, result); |
585 | } |
586 | } |
587 | |
588 | return {}; |
589 | } |
590 | } // namespace |
591 | |
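// mul(x, splat(1)) folds to x (where the integer "one" accounts for the shift,
// i.e. the value 1 << shift), mul(x, splat(0)) folds to a zero splat of the
// result type, and two splat constants fold through mulBinaryFolder above.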
592 | OpFoldResult MulOp::fold(FoldAdaptor adaptor) { |
593 | auto lhs = getInput1(); |
594 | auto rhs = getInput2(); |
595 | auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType()); |
596 | auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType()); |
597 | auto resultTy = llvm::dyn_cast<RankedTensorType>(getType()); |
598 | if (!lhsTy || !rhsTy || !resultTy) |
599 | return {}; |
600 | |
601 | auto resultETy = resultTy.getElementType(); |
602 | auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1()); |
603 | auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2()); |
604 | |
605 | const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0; |
606 | if (rhsTy == resultTy) { |
607 | if (isSplatZero(resultETy, lhsAttr)) |
608 | return lhsAttr.resizeSplat(resultTy); |
609 | if (isSplatOne(resultETy, lhsAttr, shift)) |
610 | return rhs; |
611 | } |
612 | if (lhsTy == resultTy) { |
613 | if (isSplatZero(resultETy, rhsAttr)) |
614 | return rhsAttr.resizeSplat(resultTy); |
615 | if (isSplatOne(resultETy, rhsAttr, shift)) |
616 | return lhs; |
617 | } |
618 | |
619 | return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift()); |
620 | } |
621 | |
622 | OpFoldResult SubOp::fold(FoldAdaptor adaptor) { |
623 | auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType()); |
624 | auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType()); |
625 | auto resultTy = llvm::dyn_cast<RankedTensorType>(getType()); |
626 | if (!lhsTy || !rhsTy || !resultTy) |
627 | return {}; |
628 | |
629 | auto resultETy = resultTy.getElementType(); |
630 | auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1()); |
631 | auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2()); |
632 | |
633 | if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr)) |
634 | return getInput1(); |
635 | |
636 | if (!lhsAttr || !rhsAttr) |
637 | return {}; |
638 | |
639 | return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr, |
640 | resultTy); |
641 | } |
642 | |
643 | namespace { |
644 | template <typename Cmp> |
645 | struct ComparisonFold { |
646 | ComparisonFold() = default; |
647 | APInt operator()(const APInt &l, const APInt &r) { |
648 | return APInt(1, Cmp()(l, r)); |
649 | } |
650 | |
651 | APInt operator()(const APFloat &l, const APFloat &r) { |
652 | return APInt(1, Cmp()(l, r)); |
653 | } |
654 | }; |
655 | |
656 | struct APIntFoldGreater { |
657 | APIntFoldGreater() = default; |
658 | APInt operator()(const APInt &l, const APInt &r) { |
    return APInt(1, l.sgt(r));
660 | } |
661 | }; |
662 | |
663 | struct APIntFoldGreaterEqual { |
664 | APIntFoldGreaterEqual() = default; |
665 | APInt operator()(const APInt &l, const APInt &r) { |
    return APInt(1, l.sge(r));
667 | } |
668 | }; |
669 | } // namespace |
670 | |
671 | OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { |
672 | auto resultTy = llvm::dyn_cast<RankedTensorType>(getType()); |
673 | auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1()); |
674 | auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2()); |
675 | |
676 | if (!lhsAttr || !rhsAttr) |
677 | return {}; |
678 | |
679 | return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>( |
680 | lhsAttr, rhsAttr, resultTy); |
681 | } |
682 | |
683 | OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { |
684 | auto resultTy = llvm::dyn_cast<RankedTensorType>(getType()); |
685 | auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1()); |
686 | auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2()); |
687 | |
688 | if (!lhsAttr || !rhsAttr) |
689 | return {}; |
690 | |
691 | return binaryFolder<APIntFoldGreaterEqual, |
692 | ComparisonFold<std::greater_equal<APFloat>>>( |
693 | lhsAttr, rhsAttr, resultTy); |
694 | } |
695 | |
696 | OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { |
697 | auto resultTy = llvm::dyn_cast<RankedTensorType>(getType()); |
698 | auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1()); |
699 | auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2()); |
700 | Value lhs = getInput1(); |
701 | Value rhs = getInput2(); |
702 | auto lhsTy = llvm::cast<ShapedType>(lhs.getType()); |
703 | |
  // Comparing an integer value to itself is always true. We cannot do the
  // same for floats, since a NaN value compares unequal to itself.
706 | if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy && |
707 | resultTy.hasStaticShape() && lhs == rhs) { |
708 | return DenseElementsAttr::get(resultTy, true); |
709 | } |
710 | |
711 | if (!lhsAttr || !rhsAttr) |
712 | return {}; |
713 | |
714 | return binaryFolder<ComparisonFold<std::equal_to<APInt>>, |
715 | ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr, |
716 | resultTy); |
717 | } |
718 | |
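// An identity cast folds to its input. A splat constant input folds to a splat
// of the converted value, covering float-to-float rounding, int-to-float,
// float-to-int (rounding toward zero), and int-to-int trunc/zext/sext.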
719 | OpFoldResult CastOp::fold(FoldAdaptor adaptor) { |
720 | if (getInput().getType() == getType()) |
721 | return getInput(); |
722 | |
723 | auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput()); |
724 | if (!operand) |
725 | return {}; |
726 | |
727 | auto inTy = llvm::cast<ShapedType>(getInput().getType()); |
728 | auto outTy = llvm::cast<ShapedType>(getType()); |
729 | auto inETy = inTy.getElementType(); |
730 | auto outETy = outTy.getElementType(); |
731 | |
732 | if (operand.isSplat()) { |
733 | if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) { |
734 | bool overflow; |
735 | auto splatVal = operand.getSplatValue<APFloat>(); |
736 | auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics(); |
737 | splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven, |
738 | &overflow); |
739 | return SplatElementsAttr::get(outTy, splatVal); |
740 | } |
741 | |
742 | if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) { |
743 | auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger(); |
744 | APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics()); |
745 | splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign, |
746 | llvm::RoundingMode::NearestTiesToEven); |
747 | return SplatElementsAttr::get(outTy, splatVal); |
748 | } |
749 | |
750 | if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) { |
751 | auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger(); |
752 | auto intVal = APSInt( |
753 | llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign); |
754 | auto floatVal = operand.getSplatValue<APFloat>(); |
755 | bool exact; |
756 | floatVal.convertToInteger(intVal, llvm::RoundingMode::TowardZero, &exact); |
757 | return SplatElementsAttr::get(outTy, intVal); |
758 | } |
759 | |
760 | if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) { |
761 | auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger(); |
762 | bool trunc = |
763 | inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth(); |
764 | auto intVal = operand.getSplatValue<APInt>(); |
765 | auto bitwidth = outETy.getIntOrFloatBitWidth(); |
766 | |
767 | if (trunc) { |
768 | intVal = intVal.trunc(bitwidth); |
769 | } else if (unsignIn) { |
770 | intVal = intVal.zext(bitwidth); |
771 | } else { |
772 | intVal = intVal.sext(bitwidth); |
773 | } |
774 | |
775 | return SplatElementsAttr::get(outTy, intVal); |
776 | } |
777 | } |
778 | |
779 | return {}; |
780 | } |
781 | |
782 | OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } |
783 | |
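// A reduction folds to its input when the input and result types match and
// either the tensor has rank 0 or the reduced axis already has size 1.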
784 | #define REDUCE_FOLDER(OP) \ |
785 | OpFoldResult OP::fold(FoldAdaptor adaptor) { \ |
786 | ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \ |
787 | if (!inputTy.hasRank()) \ |
788 | return {}; \ |
789 | if (inputTy != getType()) \ |
790 | return {}; \ |
791 | if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \ |
792 | return getInput(); \ |
793 | return {}; \ |
794 | } |
795 | |
796 | REDUCE_FOLDER(ReduceAllOp) |
797 | REDUCE_FOLDER(ReduceAnyOp) |
798 | REDUCE_FOLDER(ReduceMaxOp) |
799 | REDUCE_FOLDER(ReduceMinOp) |
800 | REDUCE_FOLDER(ReduceProdOp) |
801 | REDUCE_FOLDER(ReduceSumOp) |
802 | #undef REDUCE_FOLDER |
803 | |
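// Reshape folds in three cases: an identity reshape (with at most one dynamic
// dimension), reshape(reshape(x)) collapsing into a single op, and a reshape
// of a constant, which becomes a reshaped constant attribute (splats always,
// other constants only when the reshape is their sole user).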
804 | OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { |
805 | auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType()); |
806 | auto outputTy = llvm::dyn_cast<RankedTensorType>(getType()); |
807 | |
808 | if (!inputTy || !outputTy) |
809 | return {}; |
810 | |
811 | // Fold when the input and output types are the same. This is only safe when |
812 | // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions, |
813 | // there may still be a productive reshape. |
814 | if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2) |
815 | return getInput1(); |
816 | |
817 | // reshape(reshape(x)) -> reshape(x) |
818 | if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>( |
819 | getInput1().getDefiningOp())) { |
820 | getInput1Mutable().assign(reshapeOp.getInput1()); |
821 | return getResult(); |
822 | } |
823 | |
824 | // reshape(const(x)) -> const(reshape-attr(x)) |
825 | if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) { |
826 | // Constants must have static shape. |
827 | if (!outputTy.hasStaticShape()) |
828 | return {}; |
829 | |
830 | // Okay to duplicate splat constants. |
831 | if (operand.isSplat()) |
832 | return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>()); |
833 | |
834 | // Don't duplicate other constants. |
835 | if (!getInput1().hasOneUse()) |
836 | return {}; |
837 | |
838 | return operand.reshape( |
839 | llvm::cast<ShapedType>(operand.getType()).clone(getNewShape())); |
840 | } |
841 | |
842 | return {}; |
843 | } |
844 | |
845 | OpFoldResult PadOp::fold(FoldAdaptor adaptor) { |
846 | // If the pad is all zeros we can fold this operation away. |
847 | if (adaptor.getPadding()) { |
848 | auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding()); |
849 | if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) { |
850 | return getInput1(); |
851 | } |
852 | } |
853 | |
854 | return {}; |
855 | } |
856 | |
857 | // Fold away cases where a tosa.resize operation returns a copy |
858 | // of the input image. |
859 | OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) { |
860 | ArrayRef<int64_t> offset = getOffset(); |
861 | ArrayRef<int64_t> border = getBorder(); |
862 | ArrayRef<int64_t> scale = getScale(); |
863 | |
864 | // Check unit scaling. |
865 | if (scale[0] != scale[1] || scale[2] != scale[3]) { |
866 | return {}; |
867 | } |
868 | |
869 | // There should be no offset. |
870 | if (offset[0] != 0 || offset[1] != 0) { |
871 | return {}; |
872 | } |
873 | |
874 | // There should be no border. |
875 | if (border[0] != 0 || border[1] != 0) { |
876 | return {}; |
877 | } |
878 | |
879 | auto input = getInput(); |
880 | auto inputTy = llvm::cast<RankedTensorType>(input.getType()); |
881 | auto resultTy = llvm::cast<RankedTensorType>(getType()); |
882 | if (inputTy != resultTy) |
883 | return {}; |
884 | |
885 | return input; |
886 | } |
887 | |
888 | OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) { |
889 | auto operand = getInput(); |
890 | auto operandTy = llvm::cast<ShapedType>(operand.getType()); |
891 | auto axis = getAxis(); |
892 | auto operandAttr = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput()); |
893 | if (operandAttr) |
894 | return operandAttr; |
895 | |
896 | // If the dim-length is 1, tosa.reverse is a no-op. |
897 | if (operandTy.hasRank() && |
898 | (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1)) |
899 | return operand; |
900 | |
901 | return {}; |
902 | } |
903 | |
904 | OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { |
905 | auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType()); |
906 | auto outputTy = llvm::dyn_cast<RankedTensorType>(getType()); |
907 | |
908 | if (!inputTy || !outputTy) |
909 | return {}; |
910 | |
911 | if (inputTy == outputTy && inputTy.hasStaticShape()) |
912 | return getInput(); |
913 | |
914 | if (!adaptor.getInput()) |
915 | return {}; |
916 | |
917 | // Cannot create an ElementsAttr from non-int/float/index types |
918 | if (!inputTy.getElementType().isIntOrIndexOrFloat() || |
919 | !outputTy.getElementType().isIntOrIndexOrFloat()) |
920 | return {}; |
921 | |
922 | auto operand = llvm::cast<ElementsAttr>(adaptor.getInput()); |
923 | if (operand.isSplat() && outputTy.hasStaticShape()) { |
924 | return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>()); |
925 | } |
926 | |
927 | if (inputTy.hasStaticShape() && outputTy.hasStaticShape() && |
928 | outputTy.getNumElements() == 1) { |
929 | llvm::SmallVector<uint64_t> indices(getStart()); |
930 | auto value = operand.getValues<Attribute>()[indices]; |
931 | return SplatElementsAttr::get(outputTy, value); |
932 | } |
933 | |
934 | return {}; |
935 | } |
936 | |
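// select folds when both branches are the same value, or when the predicate is
// a constant splat, in which case the selected branch is returned directly.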
937 | OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { |
938 | if (getOnTrue() == getOnFalse()) |
939 | return getOnTrue(); |
940 | |
941 | auto predicate = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred()); |
942 | if (!predicate) |
943 | return {}; |
944 | |
945 | if (!predicate.isSplat()) |
946 | return {}; |
947 | return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue() |
948 | : getOnFalse(); |
949 | } |
950 | |
951 | OpFoldResult TileOp::fold(FoldAdaptor adaptor) { |
952 | bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; }); |
953 | if (allOnes && getInput1().getType() == getType()) |
954 | return getInput1(); |
955 | return {}; |
956 | } |
957 | |
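// A transpose of a splat constant folds to the splat reshaped to the result
// type; an identity permutation on a matching type folds to the input itself.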
958 | OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { |
959 | auto inputTy = llvm::cast<ShapedType>(getInput1().getType()); |
960 | auto resultTy = llvm::cast<ShapedType>(getType()); |
961 | |
962 | // Transposing splat values just means reshaping. |
963 | if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) { |
964 | if (input.isSplat() && resultTy.hasStaticShape() && |
965 | inputTy.getElementType() == resultTy.getElementType()) |
966 | return input.reshape(resultTy); |
967 | } |
968 | |
969 | // Transpose does not change the input type. |
970 | if (getInput1().getType() != getType()) |
971 | return {}; |
972 | |
973 | // Transpose is not the identity transpose. |
974 | SmallVector<int64_t> perms; |
975 | if (getConstantPerms(perms).failed()) |
976 | return {}; |
977 | |
978 | if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms)) |
979 | return {}; |
980 | |
981 | return getInput1(); |
982 | } |
983 | |
984 | OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) { |
985 | auto input = getInput1(); |
986 | // Element-wise log(exp(x)) = x |
987 | if (auto op = input.getDefiningOp<tosa::ExpOp>()) { |
988 | return op.getInput1(); |
989 | } |
990 | |
991 | return {}; |
992 | } |
993 | |
994 | OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) { |
995 | auto input = getInput1(); |
996 | // Element-wise exp(log(x)) = x |
997 | if (auto op = input.getDefiningOp<tosa::LogOp>()) { |
998 | return op.getInput1(); |
999 | } |
1000 | |
1001 | return {}; |
1002 | } |
1003 | |
1004 | OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) { |
1005 | auto input = getInput1(); |
1006 | // Element-wise negate(negate(x)) = x |
1007 | if (auto op = input.getDefiningOp<tosa::NegateOp>()) { |
1008 | return op.getInput1(); |
1009 | } |
1010 | |
1011 | return {}; |
1012 | } |
1013 | |
1014 | OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) { |
1015 | auto input = getInput1(); |
1016 | // Element-wise abs(abs(x)) = abs(x) |
1017 | if (auto op = input.getDefiningOp<tosa::AbsOp>()) { |
1018 | return input; |
1019 | } |
1020 | |
1021 | return {}; |
1022 | } |
1023 | |
1024 | OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) { |
1025 | // Fold consecutive concats on the same axis into a single op. |
1026 | // Keep track of the operands so we are able to construct a new concat |
1027 | // later. Conservatively assume that we double the number of operands when |
  // folding.
1029 | SmallVector<Value, 8> concatOperands; |
1030 | concatOperands.reserve(2 * getNumOperands()); |
1031 | |
1032 | // Find all operands that are foldable concats |
1033 | bool foundFoldableConcat = false; |
1034 | for (Value operand : getOperands()) { |
1035 | concatOperands.emplace_back(operand); |
1036 | |
1037 | auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp()); |
1038 | if (!producer) |
1039 | continue; |
1040 | |
1041 | // Not foldable if axes are not the same |
1042 | if (getAxis() != producer.getAxis()) |
1043 | continue; |
1044 | |
1045 | // Replace the original operand with all incoming operands |
1046 | foundFoldableConcat = true; |
1047 | concatOperands.pop_back(); |
1048 | llvm::append_range(concatOperands, producer->getOperands()); |
1049 | } |
1050 | |
1051 | if (!foundFoldableConcat) |
1052 | return {}; |
1053 | |
1054 | getOperation()->setOperands(concatOperands); |
1055 | return getResult(); |
1056 | } |
1057 | |
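// Folds a splat floating-point constant input by computing the reciprocal of
// the single element (via ReciprocalOp::calcOneElement) and splatting it.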
1058 | OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) { |
1059 | auto input = adaptor.getInput1(); |
1060 | |
1061 | auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input); |
1062 | // Fold splat inputs only. |
1063 | if (!inputAttr || !inputAttr.isSplat()) |
1064 | return {}; |
1065 | |
1066 | auto shapeType = llvm::cast<ShapedType>(getType()); |
1067 | if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) { |
1068 | auto floatVal = inputAttr.getSplatValue<APFloat>(); |
1069 | return DenseElementsAttr::get(shapeType, |
1070 | ReciprocalOp::calcOneElement(floatVal)); |
1071 | } |
1072 | |
1073 | return {}; |
1074 | } |
1075 | |