//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"

#include <numeric>

using namespace mlir;
using namespace mlir::tosa;

template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
  return rewriter.create<arith::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
49
50static Value
51createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
52 ArrayRef<Type> resultTypes,
53 PatternRewriter &rewriter) {
54 Location loc = op->getLoc();
55 auto elementTy =
56 cast<ShapedType>(op->getOperand(0).getType()).getElementType();
57
58 // tosa::AbsOp
59 if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
60 return rewriter.create<math::AbsFOp>(loc, resultTypes, args);
61
62 if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
63 auto zero = rewriter.create<arith::ConstantOp>(
64 loc, rewriter.getZeroAttr(elementTy));
65 auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
66 return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
67 }
68
69 // tosa::AddOp
70 if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
71 return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
72
73 if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
74 return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
75
76 // tosa::SubOp
77 if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
78 return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
79
80 if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
81 return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
82
  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
    if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
  }
92
93 // tosa::DivOp
94 if (isa<tosa::DivOp>(op) && isa<IntegerType>(elementTy))
95 return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
96
97 // tosa::ReciprocalOp
98 if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
99 auto one =
100 rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
101 return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
102 }
103
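  // tosa::MulOp (integer): when the shift attribute is positive, the product
  // is computed on i32 operands via tosa.apply_scale and truncated back to the
  // element type; otherwise the operands are sign-extended to the result width
  // and multiplied directly.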
  if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
    if (shift > 0) {
      auto shiftConst =
          rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);

      if (!b.getType().isInteger(32))
        b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
    }

    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

    return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
  }
139
140 // tosa::NegateOp
141 if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
142 return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
143
144 if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
145 !cast<tosa::NegateOp>(op).getQuantizationInfo()) {
146 auto constant =
147 rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
148 return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
149 }
150
  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
      cast<tosa::NegateOp>(op).getQuantizationInfo()) {
    auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    int64_t inZp = quantizationInfo.value().getInputZp();
    int64_t outZp = quantizationInfo.value().getOutputZp();

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to represent
    // it. We assume 48-bit numbers may be supported further in the pipeline.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied by doing:
    //   outputValue = inZp + outZp - inputValue
    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    Value min = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
        intermediateType);
    Value max = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
        intermediateType);
    auto clamp = clampIntHelper(loc, sub, min, max, rewriter);

    // Truncate to the final value.
    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
  }
195
196 // tosa::BitwiseAndOp
197 if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
198 return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
199
200 // tosa::BitwiseOrOp
201 if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
202 return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
203
  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }
211
212 // tosa::BitwiseXOrOp
213 if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
214 return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
215
216 // tosa::LogicalLeftShiftOp
217 if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
218 return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
219
220 // tosa::LogicalRightShiftOp
221 if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
222 return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
223
  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

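    // When rounding is requested, add 1 to the shifted result if the shift
    // amount is positive and the last bit shifted out of input1 is set.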
    // Checking that input2 != 0
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Checking whether the last bit shifted out of input1 is 1.
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }
261
262 // tosa::ClzOp
263 if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
264 return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
265 }
266
267 // tosa::LogicalAnd
268 if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
269 return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
270
271 // tosa::LogicalNot
272 if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
273 auto one = rewriter.create<arith::ConstantOp>(
274 loc, rewriter.getIntegerAttr(elementTy, 1));
275 return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
276 }
277
278 // tosa::LogicalOr
279 if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
280 return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
281
282 // tosa::LogicalXor
283 if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
284 return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
285
286 // tosa::PowOp
287 if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
288 return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
289
290 // tosa::RsqrtOp
291 if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
292 return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
293
294 // tosa::LogOp
295 if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
296 return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
297
298 // tosa::ExpOp
299 if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
300 return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
301
302 // tosa::TanhOp
303 if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
304 return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
305
306 // tosa::ErfOp
307 if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
308 return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
309
310 // tosa::GreaterOp
311 if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
312 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
313 args[0], args[1]);
314
315 if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
316 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
317 args[0], args[1]);
318
319 // tosa::GreaterEqualOp
320 if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
321 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
322 args[0], args[1]);
323
324 if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
325 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
326 args[0], args[1]);
327
328 // tosa::EqualOp
329 if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
330 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
331 args[0], args[1]);
332
333 if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
334 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
335 args[0], args[1]);
336
337 // tosa::SelectOp
338 if (isa<tosa::SelectOp>(op)) {
339 elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
340 if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
341 return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
342 }
343
344 // tosa::MaximumOp
345 if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
346 return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
347 }
348
349 if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
350 return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
351 }
352
353 // tosa::MinimumOp
354 if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
355 return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
356 }
357
358 if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
359 return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
360 }
361
362 // tosa::CeilOp
363 if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
364 return rewriter.create<math::CeilOp>(loc, resultTypes, args);
365
366 // tosa::FloorOp
367 if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
368 return rewriter.create<math::FloorOp>(loc, resultTypes, args);
369
  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    return clampFloatHelper(loc, args[0], min, max, rewriter);
  }

  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();

    if (intTy.isUnsignedInteger()) {
      min = std::max(min, (int64_t)0);
      max = std::min(
          max,
          APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
    } else {
      min =
          std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                            .getSExtValue());
      max =
          std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                            .getSExtValue());
    }

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
  }
413
414 // tosa::SigmoidOp
415 if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
416 auto one =
417 rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
418 auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
419 auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
420 auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
421 return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
422 }
423
424 // tosa::CastOp
425 if (isa<tosa::CastOp>(op)) {
426 Type srcTy = elementTy;
427 Type dstTy = resultTypes.front();
428 bool bitExtend =
429 srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
430
431 if (srcTy == dstTy)
432 return args.front();
433
434 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
435 return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
436 std::nullopt);
437
438 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
439 return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
440 std::nullopt);
441
442 // 1-bit integers need to be treated as signless.
443 if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
444 return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
445 std::nullopt);
446
447 if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
448 return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
449 std::nullopt);
450
    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }
463
464 // All other si-to-fp conversions should be handled by SIToFP.
465 if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
466 return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
467 std::nullopt);
468
    // Casting to boolean, floats need to only be checked as not-equal to zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }
476
    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
      // Check whether neither int min nor int max can be represented in the
      // input floating-point type due to too short exponent range.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
485 // Use cmp + select to replace infinites by int min / int max. Other
486 // integral values can be represented in the integer space.
487 auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
488 auto posInf = rewriter.create<arith::ConstantOp>(
489 loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
490 APFloat::getInf(fltSemantics)));
491 auto negInf = rewriter.create<arith::ConstantOp>(
492 loc, rewriter.getFloatAttr(
493 getElementTypeOrSelf(srcTy),
494 APFloat::getInf(fltSemantics, /*Negative=*/true)));
495 auto overflow = rewriter.create<arith::CmpFOp>(
496 loc, arith::CmpFPredicate::UEQ, rounded, posInf);
497 auto underflow = rewriter.create<arith::CmpFOp>(
498 loc, arith::CmpFPredicate::UEQ, rounded, negInf);
499 auto intMin = rewriter.create<arith::ConstantOp>(
500 loc, rewriter.getIntegerAttr(
501 getElementTypeOrSelf(dstTy),
502 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
503 auto intMax = rewriter.create<arith::ConstantOp>(
504 loc, rewriter.getIntegerAttr(
505 getElementTypeOrSelf(dstTy),
506 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
507 auto maxClamped =
508 rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
509 return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
510 maxClamped);
511 }
512
513 auto intMinFP = rewriter.create<arith::ConstantOp>(
514 loc, rewriter.getFloatAttr(
515 getElementTypeOrSelf(srcTy),
516 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
517 .getSExtValue()));
518
      // Check whether the mantissa has enough bits to represent int max.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
522 // Int min can also be represented since it is a power of two and thus
523 // consists of a single leading bit. Therefore we can clamp the input
524 // in the floating-point domain.
525
526 auto intMaxFP = rewriter.create<arith::ConstantOp>(
527 loc, rewriter.getFloatAttr(
528 getElementTypeOrSelf(srcTy),
529 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
530 .getSExtValue()));
531
532 Value clamped =
533 clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
534 return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
535 }
536
      // Due to the earlier check we know the exponent range is big enough to
      // represent int min. We can therefore rely on int max + 1 being
      // representable as well because it's just int min with a positive sign.
      // So clamp the min value and compare against that to select the max int
      // value if needed.
541 auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
542 loc, rewriter.getFloatAttr(
543 getElementTypeOrSelf(srcTy),
544 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
545 .getSExtValue() +
546 1));
547
548 auto intMax = rewriter.create<arith::ConstantOp>(
549 loc, rewriter.getIntegerAttr(
550 getElementTypeOrSelf(dstTy),
551 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
552 auto minClampedFP =
553 rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
554 auto minClamped =
555 rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
556 auto overflow = rewriter.create<arith::CmpFOp>(
557 loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
558 return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
559 minClamped);
560 }
561
    // Casting to boolean, integers need to only be checked as not-equal to
    // zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             std::nullopt);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
    }
578 }
579
  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
584
585static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
586 int64_t rank) {
587 // No need to expand if we are already at the desired rank
588 auto shapedType = dyn_cast<ShapedType>(tensor.getType());
589 assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
590 int64_t numExtraDims = rank - shapedType.getRank();
591 assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
592 if (!numExtraDims)
593 return tensor;
594
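  // For example, expanding a rank-2 tensor to rank 4 groups the extra leading
  // unit dimensions with the first source dimension, giving reassociation
  // indices [[0, 1, 2], [3]] and a result shape of [1, 1, d0, d1].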
  // Compute reassociation indices
  SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
      shapedType.getRank());
  int64_t index = 0;
  for (index = 0; index <= numExtraDims; index++)
    reassociationIndices[0].push_back(index);
  for (size_t position = 1; position < reassociationIndices.size(); position++)
    reassociationIndices[position].push_back(index++);

  // Compute result type
  SmallVector<int64_t> resultShape;
  for (index = 0; index < numExtraDims; index++)
    resultShape.push_back(1);
  for (auto size : shapedType.getShape())
    resultShape.push_back(size);
  auto resultType =
      RankedTensorType::get(resultShape, shapedType.getElementType());

  // Emit 'tensor.expand_shape' op
  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
                                                reassociationIndices);
}
617
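// Expand each operand of 'operation' to the rank of the operation's result by
// prepending unit dimensions (see expandRank above), so all operands share a
// common rank before broadcasting.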
618static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
619 Location loc, Operation *operation) {
620 auto rank =
621 cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
622 return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
623 return expandRank(rewriter, loc, operand, rank);
624 });
625}
626
627using IndexPool = DenseMap<int64_t, Value>;
628
629// Emit an 'arith.constant' op for the given index if it has not been created
630// yet, or return an existing constant. This will prevent an excessive creation
631// of redundant constants, easing readability of emitted code for unit tests.
static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
  return it->second;
}

static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
}

static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}

static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}
665
666// Compute the runtime dimension size for dimension 'dim' of the output by
667// inspecting input 'operands', all of which are expected to have the same rank.
668// This function returns a pair {targetSize, masterOperand}.
669//
670// The runtime size of the output dimension is returned either as a statically
671// computed attribute or as a runtime SSA value.
672//
673// If the target size was inferred directly from one dominating operand, that
674// operand is returned in 'masterOperand'. If the target size is inferred from
675// multiple operands, 'masterOperand' is set to nullptr.
676static std::pair<OpFoldResult, Value>
677computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
678 ValueRange operands, int64_t dim) {
679 // If any input operand contains a static size greater than 1 for this
680 // dimension, that is the target size. An occurrence of an additional static
681 // dimension greater than 1 with a different value is undefined behavior.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (!ShapedType::isDynamic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with dynamic dimension
  auto operandsWithDynamicDim =
      llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      }));

  // If no operand has a dynamic dimension, it means all sizes were 1
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension. If there is
  // only one operand with a dynamic dimension, it is considered the master
  // operand that determines the runtime size of the output dimension.
  auto targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // Calculate maximum size among all dynamic dimensions
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}
714
715// Compute the runtime output size for all dimensions. This function returns
716// a pair {targetShape, masterOperands}.
717static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
718computeTargetShape(PatternRewriter &rewriter, Location loc,
719 IndexPool &indexPool, ValueRange operands) {
720 assert(!operands.empty());
721 auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
722 SmallVector<OpFoldResult> targetShape;
723 SmallVector<Value> masterOperands;
724 for (auto dim : llvm::seq<int64_t>(0, rank)) {
725 auto [targetSize, masterOperand] =
726 computeTargetSize(rewriter, loc, indexPool, operands, dim);
727 targetShape.push_back(targetSize);
728 masterOperands.push_back(masterOperand);
729 }
730 return {targetShape, masterOperands};
731}
732
733static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
734 IndexPool &indexPool, Value operand,
735 int64_t dim, OpFoldResult targetSize,
736 Value masterOperand) {
737 // Nothing to do if this is a static dimension
738 auto rankedTensorType = cast<RankedTensorType>(operand.getType());
739 if (!rankedTensorType.isDynamicDim(dim))
740 return operand;
741
742 // If the target size for this dimension was directly inferred by only taking
743 // this operand into account, there is no need to broadcast. This is an
744 // optimization that will prevent redundant control flow, and constitutes the
745 // main motivation for tracking "master operands".
746 if (operand == masterOperand)
747 return operand;
748
749 // Affine maps for 'linalg.generic' op
750 auto rank = rankedTensorType.getRank();
751 SmallVector<AffineExpr> affineExprs;
752 for (auto index : llvm::seq<int64_t>(0, rank)) {
753 auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
754 : rewriter.getAffineDimExpr(index);
755 affineExprs.push_back(affineExpr);
756 }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if broadcast is necessary
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);
767
768 // Emit 'then' region of 'scf.if'
769 auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
770 // It is not safe to cache constants across regions.
771 // New constants could potentially violate dominance requirements.
772 IndexPool localPool;
773
774 // Emit 'tensor.empty' op
775 SmallVector<OpFoldResult> outputTensorShape;
776 for (auto index : llvm::seq<int64_t>(0, rank)) {
777 auto size = index == dim ? targetSize
778 : getOrFoldTensorDim(rewriter, loc, localPool,
779 operand, index);
780 outputTensorShape.push_back(size);
781 }
782 Value outputTensor = opBuilder.create<tensor::EmptyOp>(
783 loc, outputTensorShape, rankedTensorType.getElementType());
784
785 // Emit 'linalg.generic' op
786 auto resultTensor =
787 opBuilder
788 .create<linalg::GenericOp>(
789 loc, outputTensor.getType(), operand, outputTensor, affineMaps,
790 getNParallelLoopsAttrs(rank),
791 [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
792 // Emit 'linalg.yield' op
793 opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
794 })
795 .getResult(0);
796
797 // Cast to original operand type if necessary
798 auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
799 loc, operand.getType(), resultTensor);
800
801 // Emit 'scf.yield' op
802 opBuilder.create<scf::YieldOp>(loc, castResultTensor);
803 };
804
805 // Emit 'else' region of 'scf.if'
806 auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
807 opBuilder.create<scf::YieldOp>(loc, operand);
808 };
809
810 // Emit 'scf.if' op
811 auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
812 emitThenRegion, emitElseRegion);
813 return ifOp.getResult(0);
814}
815
816static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
817 IndexPool &indexPool, Value operand,
818 ArrayRef<OpFoldResult> targetShape,
819 ArrayRef<Value> masterOperands) {
820 int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
821 assert((int64_t)targetShape.size() == rank);
822 assert((int64_t)masterOperands.size() == rank);
823 for (auto index : llvm::seq<int64_t>(0, rank))
824 operand =
825 broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
826 targetShape[index], masterOperands[index]);
827 return operand;
828}
829
830static SmallVector<Value>
831broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
832 IndexPool &indexPool, ValueRange operands,
833 ArrayRef<OpFoldResult> targetShape,
834 ArrayRef<Value> masterOperands) {
835 // No need to broadcast for unary operations
836 if (operands.size() == 1)
837 return operands;
838
  // Broadcast dynamic dimensions operand by operand
  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
845
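// Emit the 'tensor.empty', broadcasting affine maps, and 'linalg.generic' that
// implement the elementwise operation; the op-specific scalar body is built by
// createLinalgBodyCalculationForElementwiseOp.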
846static LogicalResult
847emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
848 Operation *operation, ValueRange operands,
849 ArrayRef<OpFoldResult> targetShape) {
850 // Generate output tensor
851 auto resultType = cast<RankedTensorType>(operation->getResultTypes().front());
852 Value outputTensor = rewriter.create<tensor::EmptyOp>(
853 loc, targetShape, resultType.getElementType());
854
855 // Create affine maps. Input affine maps broadcast static dimensions of size
856 // 1. The output affine map is an identity map.
857 //
858 auto rank = resultType.getRank();
859 auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
860 auto shape = cast<ShapedType>(operand.getType()).getShape();
861 SmallVector<AffineExpr> affineExprs;
862 for (auto it : llvm::enumerate(shape)) {
863 auto affineExpr = it.value() == 1 ? rewriter.getAffineConstantExpr(0)
864 : rewriter.getAffineDimExpr(it.index());
865 affineExprs.push_back(affineExpr);
866 }
867 return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
868 });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
870
871 // Emit 'linalg.generic' op
872 bool encounteredError = false;
873 auto linalgOp = rewriter.create<linalg::GenericOp>(
874 loc, outputTensor.getType(), operands, outputTensor, affineMaps,
875 getNParallelLoopsAttrs(rank),
876 [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
877 Value opResult = createLinalgBodyCalculationForElementwiseOp(
878 operation, blockArgs.take_front(operation->getNumOperands()),
879 {resultType.getElementType()}, rewriter);
880 if (!opResult) {
881 encounteredError = true;
882 return;
883 }
884 opBuilder.create<linalg::YieldOp>(loc, opResult);
885 });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");
889
890 // Cast 'linalg.generic' result into original result type if needed
891 auto castResult = rewriter.createOrFold<tensor::CastOp>(
892 loc, resultType, linalgOp->getResult(0));
893 rewriter.replaceOp(operation, castResult);
894 return success();
895}
896
897static LogicalResult
898elementwiseMatchAndRewriteHelper(Operation *operation,
899 PatternRewriter &rewriter) {
900
901 // Collect op properties
902 assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
903 assert(operation->getNumOperands() >= 1 &&
904 "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower operation
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto expandedOperands = expandInputRanks(rewriter, loc, operation);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
  auto broadcastOperands = broadcastDynamicDimensions(
      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
                                    targetShape);
}
920
921// Returns the constant initial value for a given reduction operation. The
922// attribute type varies depending on the element type required.
923static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
924 PatternRewriter &rewriter) {
925 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
926 return rewriter.getFloatAttr(elementTy, 0.0);
927
928 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
929 return rewriter.getIntegerAttr(elementTy, 0);
930
931 if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
932 return rewriter.getFloatAttr(elementTy, 1.0);
933
934 if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
935 return rewriter.getIntegerAttr(elementTy, 1);
936
  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}
972
973// Creates the body calculation for a reduction. The operations vary depending
974// on the input type.
975static Value createLinalgBodyCalculationForReduceOp(Operation *op,
976 ValueRange args,
977 Type elementTy,
978 PatternRewriter &rewriter) {
979 Location loc = op->getLoc();
980 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
981 return rewriter.create<arith::AddFOp>(loc, args);
982 }
983
984 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
985 return rewriter.create<arith::AddIOp>(loc, args);
986 }
987
988 if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
989 return rewriter.create<arith::MulFOp>(loc, args);
990 }
991
992 if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
993 return rewriter.create<arith::MulIOp>(loc, args);
994 }
995
996 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
997 return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
998 }
999
1000 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1001 return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
1002 }
1003
1004 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1005 return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
1006 }
1007
1008 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1009 return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
1010 }
1011
1012 if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1013 return rewriter.create<arith::AndIOp>(loc, args);
1014
1015 if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1016 return rewriter.create<arith::OrIOp>(loc, args);
1017
1018 return {};
1019}
1020
1021// Performs the match and rewrite for reduction operations. This includes
1022// declaring a correctly sized initial value, and the linalg.generic operation
1023// that reduces across the specified axis.
1024static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
1025 PatternRewriter &rewriter) {
1026 auto loc = op->getLoc();
1027 auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
1028 auto resultTy = cast<ShapedType>(op->getResult(0).getType());
1029 auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);
1031
1032 SmallVector<int64_t> reduceShape;
1033 SmallVector<Value> dynDims;
1034 for (unsigned i = 0; i < inputTy.getRank(); i++) {
1035 if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
1037 if (inputTy.isDynamicDim(i))
1038 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1039 }
1040 }
1041
1042 // First fill the output buffer with the init value.
1043 auto emptyTensor =
1044 rewriter
1045 .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
1046 dynDims)
1047 .getResult();
1048
1049 auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");
1053
1054 auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
1055 auto filledTensor = rewriter
1056 .create<linalg::FillOp>(loc, ValueRange{fillValue},
1057 ValueRange{emptyTensor})
1058 .result();
1059
1060 bool didEncounterError = false;
1061 auto linalgOp = rewriter.create<linalg::ReduceOp>(
1062 loc, input, filledTensor, axis,
1063 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1064 auto result = createLinalgBodyCalculationForReduceOp(
1065 op, blockArgs, elementTy, rewriter);
1066 if (result)
1067 didEncounterError = true;
1068
1069 nestedBuilder.create<linalg::YieldOp>(loc, result);
1070 });
1071
  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");
1075
  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }
1091
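  // For example, reducing axis 1 of a rank-3 input yields a rank-2
  // 'linalg.reduce' result; the reassociation [[d0], [d1, d2]] re-expands it
  // to the rank-3 result type with a unit dimension at axis 1.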
1092 // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1093 // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1094 // not have access to such information. This matters when handling dynamically
1095 // sized tensors.
1096 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1097 op, resultTy, linalgOp.getResults()[0], reassociationMap);
1098 return success();
1099}
1100
1101namespace {
1102
1103template <typename SrcOp>
1104class PointwiseConverter : public OpRewritePattern<SrcOp> {
1105public:
1106 using OpRewritePattern<SrcOp>::OpRewritePattern;
1107
1108 LogicalResult matchAndRewrite(SrcOp op,
1109 PatternRewriter &rewriter) const final {
1110 return elementwiseMatchAndRewriteHelper(op, rewriter);
1111 }
1112};
1113
1114class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1115public:
1116 using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
1117
1118 LogicalResult matchAndRewrite(tosa::RescaleOp op,
1119 PatternRewriter &rewriter) const final {
1120 auto loc = op.getLoc();
1121 auto input = op.getInput();
1122 auto inputTy = cast<ShapedType>(op.getInput().getType());
1123 auto outputTy = cast<ShapedType>(op.getOutput().getType());
1124 unsigned rank = inputTy.getRank();
1125
    // This is an illegal configuration: terminate and log an error.
1127 if (op.getDoubleRound() && !op.getScale32())
1128 return rewriter.notifyMatchFailure(
1129 op, "tosa.rescale requires scale32 for double_round to be true");
1130
1131 SmallVector<Value> dynDims;
1132 for (int i = 0; i < outputTy.getRank(); i++) {
1133 if (outputTy.isDynamicDim(i)) {
1134 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1135 }
1136 }
1137
1138 // The shift and multiplier values.
1139 SmallVector<int32_t> multiplierValues(op.getMultiplier());
1140 SmallVector<int8_t> shiftValues(op.getShift());
1141
    // If we shift by more than the bitwidth, this just sets the result to 0.
1143 for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1144 if (shiftValues[i] > 63) {
1145 shiftValues[i] = 0;
1146 multiplierValues[i] = 0;
1147 }
1148 }
1149
    // Double rounding only occurs if the shift is greater than 31; check
    // whether that is ever the case.
    bool doubleRound =
        op.getDoubleRound() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1155
1156 SmallVector<AffineMap> indexingMaps = {
1157 rewriter.getMultiDimIdentityMap(rank)};
1158 SmallVector<Value, 4> genericInputs = {input};
1159
1160 // If we are rescaling per-channel then we need to store the multiplier
1161 // values in a buffer.
1162 Value multiplierConstant;
1163 int64_t multiplierArg = 0;
1164 if (multiplierValues.size() == 1) {
1165 multiplierConstant = rewriter.create<arith::ConstantOp>(
1166 loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1167 } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
1170 auto multiplierType =
1171 RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1172 rewriter.getI32Type());
1173 genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1174 loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1175
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));
1179
1180 multiplierArg = indexingMaps.size() - 1;
1181 }
1182
1183 // If we are rescaling per-channel then we need to store the shift
1184 // values in a buffer.
1185 Value shiftConstant;
1186 int64_t shiftArg = 0;
1187 if (shiftValues.size() == 1) {
1188 shiftConstant = rewriter.create<arith::ConstantOp>(
1189 loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1190 } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
1193 auto shiftType =
1194 RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1195 rewriter.getIntegerType(8));
1196 genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1197 loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
1201 shiftArg = indexingMaps.size() - 1;
1202 }
1203
    // Indexing maps for output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1206
1207 // Construct the indexing maps needed for linalg.generic ops.
1208 Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1209 loc, outputTy.getShape(), outputTy.getElementType(),
1210 ArrayRef<Value>({dynDims}));
1211
1212 auto linalgOp = rewriter.create<linalg::GenericOp>(
1213 loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
1214 getNParallelLoopsAttrs(rank),
1215 [&](OpBuilder &nestedBuilder, Location nestedLoc,
1216 ValueRange blockArgs) {
1217 Value value = blockArgs[0];
1218 Type valueTy = value.getType();
1219
1220 // For now we do all of our math in 64-bit. This is not optimal but
1221 // should be correct for now, consider computing correct bit depth
1222 // later.
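          // The computation below: widen the input to 32 bits if needed,
          // subtract input_zp, scale via tosa.apply_scale, add output_zp,
          // clamp to the output type's range, and truncate to the output type.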
1223 int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
1224
1225 auto inputZp = createConstFromIntAttribute<int32_t>(
1226 op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
1227 nestedBuilder);
1228 auto outputZp = createConstFromIntAttribute<int32_t>(
1229 op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
1230
1231 Value multiplier = multiplierConstant ? multiplierConstant
1232 : blockArgs[multiplierArg];
1233 Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1234
1235 if (valueTy.getIntOrFloatBitWidth() < 32) {
1236 if (valueTy.isUnsignedInteger()) {
1237 value = nestedBuilder
1238 .create<UnrealizedConversionCastOp>(
1239 nestedLoc,
1240 nestedBuilder.getIntegerType(
1241 valueTy.getIntOrFloatBitWidth()),
1242 value)
1243 .getResult(0);
1244 value = nestedBuilder.create<arith::ExtUIOp>(
1245 nestedLoc, nestedBuilder.getI32Type(), value);
1246 } else {
1247 value = nestedBuilder.create<arith::ExtSIOp>(
1248 nestedLoc, nestedBuilder.getI32Type(), value);
1249 }
1250 }
1251
1252 value =
1253 nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
1254
1255 value = nestedBuilder.create<tosa::ApplyScaleOp>(
1256 loc, nestedBuilder.getI32Type(), value, multiplier, shift,
1257 nestedBuilder.getBoolAttr(doubleRound));
1258
1259 // Move to the new zero-point.
1260 value =
1261 nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
1262
1263 // Saturate to the output size.
1264 IntegerType outIntType =
1265 cast<IntegerType>(blockArgs.back().getType());
1266 unsigned outBitWidth = outIntType.getWidth();
1267
1268 int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1269 int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1270
          // Unsigned integers have a different output value range.
1272 if (outIntType.isUnsignedInteger()) {
1273 intMin = 0;
1274 intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1275 }
1276
1277 auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
1278 loc, nestedBuilder.getI32IntegerAttr(intMin));
1279 auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
1280 loc, nestedBuilder.getI32IntegerAttr(intMax));
1281
1282 value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1283 nestedBuilder);
1284
1285 if (outIntType.getWidth() < 32) {
1286 value = nestedBuilder.create<arith::TruncIOp>(
1287 nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
1288 value);
1289
1290 if (outIntType.isUnsignedInteger()) {
1291 value = nestedBuilder
1292 .create<UnrealizedConversionCastOp>(nestedLoc,
1293 outIntType, value)
1294 .getResult(0);
1295 }
1296 }
1297
1298 nestedBuilder.create<linalg::YieldOp>(loc, value);
1299 });
1300
1301 rewriter.replaceOp(op, linalgOp->getResults());
1302 return success();
1303 }
1304};
1305
// Handle the resize case where the input is a 1x1 image. This case can
// entirely avoid the extract operations, which are much more difficult to
// optimize away.
1309class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1310public:
1311 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1312
1313 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1314 PatternRewriter &rewriter) const final {
1315 Location loc = op.getLoc();
1316 ImplicitLocOpBuilder builder(loc, rewriter);
1317 auto input = op.getInput();
1318 auto inputTy = cast<RankedTensorType>(input.getType());
1319 auto resultTy = cast<RankedTensorType>(op.getType());
1320 const bool isBilinear = op.getMode() == "BILINEAR";
1321
1322 auto inputH = inputTy.getDimSize(1);
1323 auto inputW = inputTy.getDimSize(2);
1324 auto outputH = resultTy.getDimSize(1);
1325 auto outputW = resultTy.getDimSize(2);
1326
1327 if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1328 return rewriter.notifyMatchFailure(
1329 op, "tosa.resize is not a pure 1x1->1x1 image operation");
1330
    // TODO(suderman): These string values should be declared in the TOSA
    // dialect.
1332 if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1333 return rewriter.notifyMatchFailure(
1334 op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1335
1336 if (inputTy == resultTy) {
1337 rewriter.replaceOp(op, input);
1338 return success();
1339 }
1340
1341 ArrayRef<int64_t> scale = op.getScale();
1342
    // Collapse the unit width and height away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1349
1350 auto collapseTy =
1351 RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
1352 inputTy.getElementType());
1353 Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
1354 reassociationMap);
1355
1356 // Get any dynamic shapes that appear in the input format.
1357 llvm::SmallVector<Value> outputDynSize;
1358 if (inputTy.isDynamicDim(0))
1359 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1360 if (inputTy.isDynamicDim(3))
1361 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1362
    // Generate the elementwise operation for casting and scaling the input
    // value.
1364 auto genericTy = collapseTy.clone(resultTy.getElementType());
1365 Value empty = builder.create<tensor::EmptyOp>(
1366 genericTy.getShape(), resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1368 SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1369 utils::IteratorType::parallel);
1370
1371 auto generic = builder.create<linalg::GenericOp>(
1372 genericTy, ValueRange{collapse}, ValueRange{empty},
1373 ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
1374 [=](OpBuilder &b, Location loc, ValueRange args) {
1375 Value value = args[0];
1376 // This is the quantized case.
1377 if (inputTy.getElementType() != resultTy.getElementType()) {
1378 value =
1379 b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
1380
1381 if (isBilinear && scale[0] != 0) {
1382 Value scaleY = b.create<arith::ConstantOp>(
1383 loc, b.getI32IntegerAttr(scale[0]));
1384 value = b.create<arith::MulIOp>(loc, value, scaleY);
1385 }
1386
1387 if (isBilinear && scale[2] != 0) {
1388 Value scaleX = b.create<arith::ConstantOp>(
1389 loc, b.getI32IntegerAttr(scale[2]));
1390 value = b.create<arith::MulIOp>(loc, value, scaleX);
1391 }
1392 }
1393
1394 b.create<linalg::YieldOp>(loc, value);
1395 });
1396
1397 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1398 op, resultTy, generic.getResults()[0], reassociationMap);
1399 return success();
1400 }
1401};
1402
1403// TOSA resize with width or height of 1 may be broadcasted to a wider
1404// dimension. This is done by materializing a new tosa.resize without
1405// the broadcasting behavior, and an explicit broadcast afterwards.
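// For example, when the input height is 1 the resize is re-emitted with a
// unit output height (the width is still resized as usual); the unit
// dimension is then collapsed away and a linalg.generic broadcasts the
// result up to the requested output height.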
1406class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
1407public:
1408 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1409
1410 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1411 PatternRewriter &rewriter) const final {
1412 Location loc = op.getLoc();
1413 ImplicitLocOpBuilder builder(loc, rewriter);
1414 auto input = op.getInput();
1415 auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1416 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1417
1418 if (!inputTy || !resultTy)
1419 return rewriter.notifyMatchFailure(op,
1420 "requires ranked input/output types");
1421
1422 auto batch = inputTy.getDimSize(0);
1423 auto channels = inputTy.getDimSize(3);
1424 auto inputH = inputTy.getDimSize(1);
1425 auto inputW = inputTy.getDimSize(2);
1426 auto outputH = resultTy.getDimSize(1);
1427 auto outputW = resultTy.getDimSize(2);
1428
1429 if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1430 return rewriter.notifyMatchFailure(
1431 op, "tosa.resize has no broadcasting behavior");
1432
    // For any dimension that is broadcastable we generate a size of 1
    // on the output.
1435 llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);
1440
1441 auto resizeTy = resultTy.clone(resizeShape);
1442 auto resize =
1443 builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());
1444
    // Collapse any unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));
1455
1456 llvm::SmallVector<int64_t> collapseShape{batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);
1462
1463 auto collapseTy = resultTy.clone(collapseShape);
1464 Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
1465 reassociationMap);
1466
1467 // Broadcast the collapsed shape to the output result.
1468 llvm::SmallVector<Value> outputDynSize;
1469 if (inputTy.isDynamicDim(0))
1470 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1471 if (inputTy.isDynamicDim(3))
1472 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1473
1474 SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
1475 utils::IteratorType::parallel);
1476 Value empty = builder.create<tensor::EmptyOp>(
1477 resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1478
    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));
1485
1486 auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
1487 inputExprs, rewriter.getContext());
1488
    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
1490 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1491 op, resultTy, ValueRange{collapse}, ValueRange{empty},
1492 ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1493 [=](OpBuilder &b, Location loc, ValueRange args) {
1494 Value value = args[0];
1495 b.create<linalg::YieldOp>(loc, value);
1496 });
1497
1498 return success();
1499 }
1500};
1501
1502class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1503public:
1504 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1505
1506 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1507 PatternRewriter &rewriter) const final {
1508 Location loc = op.getLoc();
1509 ImplicitLocOpBuilder b(loc, rewriter);
1510 auto input = op.getInput();
1511 auto inputTy = cast<ShapedType>(input.getType());
1512 auto resultTy = cast<ShapedType>(op.getType());
1513 auto resultETy = resultTy.getElementType();
1514
1515 bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
1516 auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
1517
1518 auto imageH = inputTy.getShape()[1];
1519 auto imageW = inputTy.getShape()[2];
1520
1521 auto dynamicDimsOr =
1522 checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1523 if (!dynamicDimsOr.has_value())
1524 return rewriter.notifyMatchFailure(
1525 op, "unable to get dynamic dimensions of tosa.resize");
1526
1527 if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1528 return rewriter.notifyMatchFailure(
1529 op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1530
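    // The resize lowers to a linalg.generic over the output tensor whose body
    // recomputes the source coordinates for each output element and gathers
    // the required input samples with tensor.extract, either taking the
    // nearest neighbor or interpolating bilinearly.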
1531 SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1533 auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
1534 *dynamicDimsOr);
1535 auto genericOp = b.create<linalg::GenericOp>(
1536 resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1537 getNParallelLoopsAttrs(resultTy.getRank()));
1538 Value resize = genericOp.getResult(0);
1539
1540 {
1541 OpBuilder::InsertionGuard regionGuard(b);
1542 b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1543 TypeRange({resultETy}), loc);
1544 Value batch = b.create<linalg::IndexOp>(0);
1545 Value y = b.create<linalg::IndexOp>(1);
1546 Value x = b.create<linalg::IndexOp>(2);
1547 Value channel = b.create<linalg::IndexOp>(3);
1548
1549 Value zeroI32 =
1550 b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
1551 Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
1552 Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
1553 Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
1554
1555 Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
1556 Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
1557
1558 ArrayRef<int64_t> offset = op.getOffset();
1559 ArrayRef<int64_t> border = op.getBorder();
1560 ArrayRef<int64_t> scale = op.getScale();
1561
1562 Value yScaleN, yScaleD, xScaleN, xScaleD;
1563 yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
1564 yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
1565 xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
1566 xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
1567
1568 Value yOffset, xOffset, yBorder, xBorder;
1569 yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
1570 xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
1571 yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
1572 xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
1573
1574 // Compute the ix and dx values for both the X and Y dimensions.
1575 auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
1576 Value scaleN, Value scaleD, Value offset,
1577 int size, ImplicitLocOpBuilder &b) {
1578 if (size == 1) {
1579 index = zeroI32;
1580 delta = zeroFp;
1581 return;
1582 }
1583 // x = x * scale_d + offset;
1584 // ix = floor(x / scale_n)
1585 Value val = b.create<arith::MulIOp>(in, scaleD);
1586 val = b.create<arith::AddIOp>(val, offset);
1587 index = b.create<arith::FloorDivSIOp>(val, scaleN);
1588
1589 // rx = x % scale_n
1590 // dx = rx / scale_n
1591 Value r = b.create<arith::RemSIOp>(val, scaleN);
1592 Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
1593 Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
1594 delta = b.create<arith::DivFOp>(rFp, scaleNfp);
1595 };
1596
1597 // Compute the ix and dx values for the X and Y dimensions - int case.
1598 auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
1599 Value scaleN, Value scaleD, Value offset,
1600 int size, ImplicitLocOpBuilder &b) {
1601 if (size == 1) {
1602 index = zeroI32;
1603 delta = zeroI32;
1604 return;
1605 }
1606 // x = x * scale_d + offset;
1607 // ix = floor(x / scale_n)
1608 // dx = x - ix * scale_n;
1609 Value val = b.create<arith::MulIOp>(in, scaleD);
1610 val = b.create<arith::AddIOp>(val, offset);
1611 index = b.create<arith::DivSIOp>(val, scaleN);
1612 delta = b.create<arith::MulIOp>(index, scaleN);
1613 delta = b.create<arith::SubIOp>(val, delta);
1614 };
1615
1616 Value ix, iy, dx, dy;
1617 if (floatingPointMode) {
1618 getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1619 getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1620 } else {
1621 getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1622 getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1623 }
1624
1625 if (op.getMode() == "NEAREST_NEIGHBOR") {
1626 auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1627
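        // Round to the nearest input sample: bump the index by one when the
        // fractional part is at least one half (dval >= 0.5 in floating-point
        // mode, 2 * dval >= scale_n in integer mode), then clamp the result
        // into [0, size - 1].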
1628 auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
1629 Value max, int size,
1630 ImplicitLocOpBuilder &b) -> Value {
1631 if (size == 1) {
1632 return b.create<arith::ConstantIndexOp>(0);
1633 }
1634
1635 Value pred;
1636 if (floatingPointMode) {
1637 auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
1638 pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
1639 } else {
1640 Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
1641 pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
1642 dvalDouble, scale);
1643 }
1644
1645 auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
1646 val = b.create<arith::AddIOp>(val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b);
1648 return b.create<arith::IndexCastOp>(b.getIndexType(), val);
1649 };
1650
1651 iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
1652 ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
1653
1654 Value result = b.create<tensor::ExtractOp>(
1655 input, ValueRange{batch, iy, ix, channel});
1656
1657 b.create<linalg::YieldOp>(result);
1658 } else {
1659 // The mode here must be BILINEAR.
1660 assert(op.getMode() == "BILINEAR");
1661
1662 auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1663
1664 auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
1665 Value max, ImplicitLocOpBuilder &b) {
1666 val0 = in;
1667 val1 = b.create<arith::AddIOp>(val0, oneVal);
          val0 = clampIntHelper(loc, val0, zeroI32, max, b);
          val1 = clampIntHelper(loc, val1, zeroI32, max, b);
1670 val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
1671 val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
1672 };
1673
1674 // Linalg equivalent to the section below:
1675 // int16_t iy0 = apply_max(iy, 0);
1676 // int16_t iy1 = apply_min(iy + 1, IH - 1);
1677 // int16_t ix0 = apply_max(ix, 0);
1678 // int16_t ix1 = apply_min(ix + 1, IW - 1);
1679 Value x0, x1, y0, y1;
1680 getClampedIdxs(y0, y1, imageH, iy, hMax, b);
1681 getClampedIdxs(x0, x1, imageW, ix, wMax, b);
1682
1683 Value y0x0 = b.create<tensor::ExtractOp>(
1684 input, ValueRange{batch, y0, x0, channel});
1685 Value y0x1 = b.create<tensor::ExtractOp>(
1686 input, ValueRange{batch, y0, x1, channel});
1687 Value y1x0 = b.create<tensor::ExtractOp>(
1688 input, ValueRange{batch, y1, x0, channel});
1689 Value y1x1 = b.create<tensor::ExtractOp>(
1690 input, ValueRange{batch, y1, x1, channel});
1691
1692 if (floatingPointMode) {
1693 auto oneVal =
1694 b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
1695 auto interpolate = [&](Value val0, Value val1, Value delta,
1696 int inputSize,
1697 ImplicitLocOpBuilder &b) -> Value {
1698 if (inputSize == 1)
1699 return val0;
1700 Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
1701 Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
1702 Value mul1 = b.create<arith::MulFOp>(val1, delta);
1703 return b.create<arith::AddFOp>(mul0, mul1);
1704 };
1705
1706 // Linalg equivalent to the section below:
1707 // topAcc = v00 * (unit_x - dx);
1708 // topAcc += v01 * dx;
1709 Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
1710
1711 // Linalg equivalent to the section below:
1712 // bottomAcc = v10 * (unit_x - dx);
1713 // bottomAcc += v11 * dx;
1714 Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
1715
1716 // Linalg equivalent to the section below:
1717 // result = topAcc * (unit_y - dy) + bottomAcc * dy
1718 Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
1719 b.create<linalg::YieldOp>(result);
1720 } else {
1721 // Perform in quantized space.
1722 y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
1723 y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
1724 y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
1725 y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);
1726
1727 const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
1728 if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
1729 dx = b.create<arith::ExtSIOp>(resultETy, dx);
1730 dy = b.create<arith::ExtSIOp>(resultETy, dy);
1731 }
1732
1733 Value yScaleNExt = yScaleN;
1734 Value xScaleNExt = xScaleN;
1735
1736 const int64_t scaleBitwidth =
1737 xScaleN.getType().getIntOrFloatBitWidth();
1738 if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
1739 yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
1740 xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
1741 }
1742
1743 auto interpolate = [](Value val0, Value val1, Value weight1,
1744 Value scale, int inputSize,
1745 ImplicitLocOpBuilder &b) -> Value {
1746 if (inputSize == 1)
1747 return b.create<arith::MulIOp>(val0, scale);
1748 Value weight0 = b.create<arith::SubIOp>(scale, weight1);
1749 Value mul0 = b.create<arith::MulIOp>(val0, weight0);
1750 Value mul1 = b.create<arith::MulIOp>(val1, weight1);
1751 return b.create<arith::AddIOp>(mul0, mul1);
1752 };
1753
1754 Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
1755 Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
1756 Value result =
1757 interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
1758 b.create<linalg::YieldOp>(result);
1759 }
1760 }
1761 }
1762
1763 rewriter.replaceOp(op, resize);
1764 return success();
1765 }
1766};
1767
1768// At the codegen level any identity operations should be removed. Any cases
1769// where identity is load-bearing (e.g. cross device computation) should be
1770// handled before lowering to codegen.
1771template <typename SrcOp>
1772class IdentityNConverter : public OpRewritePattern<SrcOp> {
1773public:
1774 using OpRewritePattern<SrcOp>::OpRewritePattern;
1775
1776 LogicalResult matchAndRewrite(SrcOp op,
1777 PatternRewriter &rewriter) const final {
1778 rewriter.replaceOp(op, op.getOperation()->getOperands());
1779 return success();
1780 }
1781};
1782
1783template <typename SrcOp>
1784class ReduceConverter : public OpRewritePattern<SrcOp> {
1785public:
1786 using OpRewritePattern<SrcOp>::OpRewritePattern;
1787
1788 LogicalResult matchAndRewrite(SrcOp reduceOp,
1789 PatternRewriter &rewriter) const final {
1790 return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
1791 }
1792};
1793
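// Lowers tosa.reverse to a linalg.generic whose body mirrors the index along
// the reversed axis, reading input[..., size - 1 - i, ...] via tensor.extract.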
1794class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
1795public:
1796 using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
1797
1798 LogicalResult matchAndRewrite(tosa::ReverseOp op,
1799 PatternRewriter &rewriter) const final {
1800 auto loc = op.getLoc();
1801 Value input = op.getInput();
1802 auto inputTy = cast<ShapedType>(input.getType());
1803 auto resultTy = cast<ShapedType>(op.getType());
1804 auto axis = op.getAxis();
1805
1806 SmallVector<Value> dynDims;
1807 for (int i = 0; i < inputTy.getRank(); i++) {
1808 if (inputTy.isDynamicDim(i)) {
1809 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1810 }
1811 }
1812
1813 Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
1814
1815 // First fill the output buffer with the init value.
1816 auto emptyTensor = rewriter
1817 .create<tensor::EmptyOp>(loc, inputTy.getShape(),
1818 inputTy.getElementType(),
1819 ArrayRef<Value>({dynDims}))
1820 .getResult();
1821 SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1823
1824 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1825 op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
1826 getNParallelLoopsAttrs(resultTy.getRank()),
1827 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1828 llvm::SmallVector<Value> indices;
1829 for (unsigned int i = 0; i < inputTy.getRank(); i++) {
1830 Value index =
1831 rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
1832 if (i == axis) {
1833 auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
1834 auto sizeMinusOne =
1835 rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
1836 index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
1837 index);
1838 }
1839
1840 indices.push_back(index);
1841 }
1842
1843 auto extract = nestedBuilder.create<tensor::ExtractOp>(
1844 nestedLoc, input, indices);
1845 nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
1846 extract.getResult());
1847 });
1848 return success();
1849 }
1850};
1851
// This converter translates a tile operation into a reshape, broadcast,
// reshape sequence. The first reshape minimally expands each tiled dimension
// to include a preceding size-1 dim. This dim is then broadcast to the
// appropriate multiple.
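// For example, tiling a 2x3 input with multiples [2, 1] materializes a
// 2x2x1x3 linalg.generic result which is then reshaped to the 4x3 output.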
1856struct TileConverter : public OpConversionPattern<tosa::TileOp> {
1857 using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
1858
1859 LogicalResult
1860 matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
1861 ConversionPatternRewriter &rewriter) const override {
1862 auto loc = op.getLoc();
1863 auto input = op.getInput1();
1864 auto inputTy = cast<ShapedType>(input.getType());
1865 auto inputShape = inputTy.getShape();
1866 auto resultTy = cast<ShapedType>(op.getType());
1867 auto elementTy = inputTy.getElementType();
1868 int64_t rank = inputTy.getRank();
1869
1870 ArrayRef<int64_t> multiples = op.getMultiples();
1871
1872 // Broadcast the newly added dimensions to their appropriate multiple.
1873 SmallVector<int64_t, 2> genericShape;
1874 for (int i = 0; i < rank; i++) {
1875 int64_t dim = multiples[i];
1876 genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
1878 }
1879
1880 SmallVector<Value> dynDims;
1881 for (int i = 0; i < inputTy.getRank(); i++) {
1882 if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
1883 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1884 }
1885 }
1886
1887 auto emptyTensor = rewriter.create<tensor::EmptyOp>(
1888 op.getLoc(), genericShape, elementTy, dynDims);
1889
    // We need to map the input shape to the non-broadcasted dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
1902
1903 auto genericOp = rewriter.create<linalg::GenericOp>(
1904 loc, RankedTensorType::get(genericShape, elementTy), input,
1905 ValueRange{emptyTensor}, affineMaps,
1906 getNParallelLoopsAttrs(genericShape.size()),
1907 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1908 nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
1909 });
1910
1911 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1912 op, resultTy, genericOp.getResult(0),
1913 rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
1914 return success();
1915 }
1916};
1917
// Tosa argmax lowering represents the ArgMax op as a linalg.generic op,
// producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and has the resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by the generic op, this buffer is discarded as only the index is
// requested.
//
// The generic op updates both the maximum value and index if the
// current value exceeds the running max.
1931class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
1932public:
1933 using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
1934
1935 LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
1936 PatternRewriter &rewriter) const final {
1937 auto loc = argmaxOp.getLoc();
1938 Value input = argmaxOp.getInput();
1939 auto inputTy = cast<ShapedType>(input.getType());
1940 auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
1941 auto inElementTy = inputTy.getElementType();
1942 auto outElementTy = resultTy.getElementType();
1943 int axis = argmaxOp.getAxis();
1944 auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
1945
1946 if (!isa<IntegerType>(outElementTy))
1947 return rewriter.notifyMatchFailure(
1948 argmaxOp,
1949 "tosa.arg_max to linalg.* requires integer-like result type");
1950
1951 SmallVector<Value> dynDims;
1952 for (int i = 0; i < inputTy.getRank(); i++) {
1953 if (inputTy.isDynamicDim(i) && i != axis) {
1954 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1955 }
1956 }
1957
1958 // First fill the output buffer for the index.
1959 auto emptyTensorIdx = rewriter
1960 .create<tensor::EmptyOp>(loc, resultTy.getShape(),
1961 outElementTy, dynDims)
1962 .getResult();
1963 auto fillValueIdx = rewriter.create<arith::ConstantOp>(
1964 loc, rewriter.getIntegerAttr(outElementTy, 0));
1965 auto filledTensorIdx =
1966 rewriter
1967 .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
1968 ValueRange{emptyTensorIdx})
1969 .result();
1970
1971 // Second fill the output buffer for the running max.
1972 auto emptyTensorMax = rewriter
1973 .create<tensor::EmptyOp>(loc, resultTy.getShape(),
1974 inElementTy, dynDims)
1975 .getResult();
1976 auto fillValueMaxAttr =
1977 createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
1978
1979 if (!fillValueMaxAttr)
1980 return rewriter.notifyMatchFailure(
1981 argmaxOp, "unsupported tosa.argmax element type");
1982
1983 auto fillValueMax =
1984 rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
1985 auto filledTensorMax =
1986 rewriter
1987 .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
1988 ValueRange{emptyTensorMax})
1989 .result();
1990
1991 // We need to reduce along the arg-max axis, with parallel operations along
1992 // the rest.
1993 SmallVector<utils::IteratorType, 4> iteratorTypes;
1994 iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
1995 iteratorTypes[axis] = utils::IteratorType::reduction;
1996
1997 SmallVector<AffineExpr, 2> srcExprs;
1998 SmallVector<AffineExpr, 2> dstExprs;
1999 for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2003 }
2004
2005 bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
2008 auto linalgOp = rewriter.create<linalg::GenericOp>(
2009 loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2010 ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2011 [&](OpBuilder &nestedBuilder, Location nestedLoc,
2012 ValueRange blockArgs) {
2013 auto newValue = blockArgs[0];
2014 auto oldIndex = blockArgs[1];
2015 auto oldValue = blockArgs[2];
2016
2017 Value newIndex = rewriter.create<arith::IndexCastOp>(
2018 nestedLoc, oldIndex.getType(),
2019 rewriter.create<linalg::IndexOp>(loc, axis));
2020
2021 Value predicate;
2022 if (isa<FloatType>(inElementTy)) {
2023 predicate = rewriter.create<arith::CmpFOp>(
2024 nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2025 } else if (isa<IntegerType>(inElementTy)) {
2026 predicate = rewriter.create<arith::CmpIOp>(
2027 nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2028 } else {
2029 didEncounterError = true;
2030 return;
2031 }
2032
2033 auto resultMax = rewriter.create<arith::SelectOp>(
2034 nestedLoc, predicate, newValue, oldValue);
2035 auto resultIndex = rewriter.create<arith::SelectOp>(
2036 nestedLoc, predicate, newIndex, oldIndex);
2037 nestedBuilder.create<linalg::YieldOp>(
2038 nestedLoc, ValueRange({resultIndex, resultMax}));
2039 });
2040
2041 if (didEncounterError)
2042 return rewriter.notifyMatchFailure(
2043 argmaxOp, "unsupported tosa.argmax element type");
2044
2045 rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2046 return success();
2047 }
2048};
2049
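// Lowers tosa.gather to a linalg.generic indexed over the result: each output
// element result[batch, index, channel] is fetched from
// values[batch, indices[batch, index], channel] with a tensor.extract.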
2050class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2051public:
2052 using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2053 LogicalResult
2054 matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2055 ConversionPatternRewriter &rewriter) const final {
2056 auto input = adaptor.getOperands()[0];
2057 auto indices = adaptor.getOperands()[1];
2058
2059 auto valuesTy =
2060 dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
2061 auto resultTy = cast<ShapedType>(op.getType());
2062
2063 if (!valuesTy)
2064 return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
2065
    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2068
2069 auto resultElementTy = resultTy.getElementType();
2070
2071 auto loc = op.getLoc();
2072 auto emptyTensor =
2073 rewriter
2074 .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
2075 dynamicDims)
2076 .getResult();
2077
2078 SmallVector<AffineMap, 2> affineMaps = {
2079 AffineMap::get(
2080 /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2084
2085 auto genericOp = rewriter.create<linalg::GenericOp>(
2086 loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2087 ValueRange{emptyTensor}, affineMaps,
2088 getNParallelLoopsAttrs(resultTy.getRank()),
2089 [&](OpBuilder &b, Location loc, ValueRange args) {
2090 auto indexValue = args[0];
2091 auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
2092 Value index1 = rewriter.create<arith::IndexCastOp>(
2093 loc, rewriter.getIndexType(), indexValue);
2094 auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
2095 Value extract = rewriter.create<tensor::ExtractOp>(
2096 loc, input, ValueRange{index0, index1, index2});
2097 rewriter.create<linalg::YieldOp>(loc, extract);
2098 });
2099 rewriter.replaceOp(op, genericOp.getResult(0));
2100 return success();
2101 }
2102
2103 static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2104 Location loc,
2105 Value values,
2106 Value indices) {
2107 llvm::SmallVector<Value> results;
2108
2109 auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
2113 };
2114
2115 addDynamicDimension(values, 0);
2116 addDynamicDimension(indices, 1);
2117 addDynamicDimension(values, 2);
2118 return results;
2119 }
2120};
2121
// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
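// For the i8 case the input is shifted from [-128, 127] into [0, 255] and used
// to index the table directly; for the i16 case the top nine bits of the
// offset value select a base table entry and the low seven bits linearly
// interpolate toward the next entry.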
2125class TableConverter : public OpRewritePattern<tosa::TableOp> {
2126public:
2127 using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2128
2129 LogicalResult matchAndRewrite(tosa::TableOp op,
2130 PatternRewriter &rewriter) const final {
2131 auto loc = op.getLoc();
2132 Value input = op.getInput();
2133 Value table = op.getTable();
2134 auto inputTy = cast<ShapedType>(input.getType());
2135 auto tableTy = cast<ShapedType>(table.getType());
2136 auto resultTy = cast<ShapedType>(op.getType());
2137
2138 auto inputElementTy = inputTy.getElementType();
2139 auto tableElementTy = tableTy.getElementType();
2140 auto resultElementTy = resultTy.getElementType();
2141
2142 SmallVector<Value> dynDims;
2143 for (int i = 0; i < resultTy.getRank(); ++i) {
2144 if (inputTy.isDynamicDim(i)) {
2145 dynDims.push_back(
2146 rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
2147 }
2148 }
2149
2150 auto emptyTensor = rewriter
2151 .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2152 resultElementTy, dynDims)
2153 .getResult();
2154
2155 SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2158
2159 auto genericOp = rewriter.create<linalg::GenericOp>(
2160 loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
2161 getNParallelLoopsAttrs(resultTy.getRank()));
2162 rewriter.replaceOp(op, genericOp.getResult(0));
2163
2164 {
2165 OpBuilder::InsertionGuard regionGuard(rewriter);
2166 Block *block = rewriter.createBlock(
2167 &genericOp.getRegion(), genericOp.getRegion().end(),
2168 TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2169
      auto inputValue = block->getArgument(0);
2171 rewriter.setInsertionPointToStart(block);
2172 if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2173 resultElementTy.isInteger(8)) {
2174 Value index = rewriter.create<arith::IndexCastOp>(
2175 loc, rewriter.getIndexType(), inputValue);
2176 Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
2177 index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
2178 index, offset);
2179 Value extract =
2180 rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2181 rewriter.create<linalg::YieldOp>(loc, extract);
2182 return success();
2183 }
2184
2185 if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2186 resultElementTy.isInteger(32)) {
2187 Value extend = rewriter.create<arith::ExtSIOp>(
2188 loc, rewriter.getI32Type(), inputValue);
2189
2190 auto offset = rewriter.create<arith::ConstantOp>(
2191 loc, rewriter.getI32IntegerAttr(32768));
2192 auto seven = rewriter.create<arith::ConstantOp>(
2193 loc, rewriter.getI32IntegerAttr(7));
2194 auto one = rewriter.create<arith::ConstantOp>(
2195 loc, rewriter.getI32IntegerAttr(1));
2196 auto b1111111 = rewriter.create<arith::ConstantOp>(
2197 loc, rewriter.getI32IntegerAttr(127));
2198
2199 // Compute the index and fractional part from the input value:
2200 // value = value + 32768
2201 // index = value >> 7;
        // fraction = 0b01111111 & value
2203 auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
2204 Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
2205 Value fraction =
2206 rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
2207
2208 // Extract the base and next values from the table.
2209 // base = (int32_t) table[index];
2210 // next = (int32_t) table[index + 1];
2211 Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
2212
2213 index = rewriter.create<arith::IndexCastOp>(
2214 loc, rewriter.getIndexType(), index);
2215 indexPlusOne = rewriter.create<arith::IndexCastOp>(
2216 loc, rewriter.getIndexType(), indexPlusOne);
2217
2218 Value base =
2219 rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2220 Value next = rewriter.create<tensor::ExtractOp>(
2221 loc, table, ValueRange{indexPlusOne});
2222
2223 base =
2224 rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
2225 next =
2226 rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
2227
2228 // Use the fractional part to interpolate between the input values:
2229 // result = (base << 7) + (next - base) * fraction
2230 Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
2231 Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
2232 Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
2233 Value result =
2234 rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
2235
2236 rewriter.create<linalg::YieldOp>(loc, result);
2237
2238 return success();
2239 }
2240 }
2241
2242 return rewriter.notifyMatchFailure(
2243 op, "unable to create body for tosa.table op");
2244 }
2245};
2246
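// Lowers tosa.rfft2d to a linalg.generic that accumulates the DFT sums: for
// each output element (n, oy, ox) it reduces over (iy, ix), adding
// in[n, iy, ix] * cos(angle) to the real part and subtracting
// in[n, iy, ix] * sin(angle) from the imaginary part, where
// angle = 2 * pi * (((iy * oy) % H) / H + ((ix * ox) % W) / W). Only the
// first W / 2 + 1 output columns are materialized.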
2247struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
2248 using OpRewritePattern<RFFT2dOp>::OpRewritePattern;
2249
  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
2251
2252 static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2253 OpFoldResult ofr) {
    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
2258 auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
2259 auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
2260 return getAsOpFoldResult(plusOne);
2261 }
2262
2263 static RankedTensorType
2264 computeOutputShape(OpBuilder &builder, Location loc, Value input,
2265 llvm::SmallVectorImpl<Value> &dynamicSizes) {
2266 // Get [N, H, W]
    auto dims = tensor::getMixedSizes(builder, loc, input);
2268
2269 // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2270 // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);
2272
2273 llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2275
2276 auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
2277 return RankedTensorType::get(staticSizes, elementType);
2278 }
2279
2280 static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2281 RankedTensorType type,
2282 llvm::ArrayRef<Value> dynamicSizes) {
2283 auto emptyTensor =
2284 rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
2286 auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
2287 auto filledTensor = rewriter
2288 .create<linalg::FillOp>(loc, ValueRange{fillValue},
2289 ValueRange{emptyTensor})
2290 .result();
2291 return filledTensor;
2292 }
2293
2294 static Value castIndexToFloat(OpBuilder &builder, Location loc,
2295 FloatType type, Value value) {
2296 auto integerVal = builder.create<arith::IndexCastUIOp>(
2297 loc,
2298 type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
2299 : builder.getI32Type(),
2300 value);
2301
2302 return builder.create<arith::UIToFPOp>(loc, type, integerVal);
2303 }
2304
2305 static Value createLinalgIndex(OpBuilder &builder, Location loc,
2306 FloatType type, int64_t index) {
2307 auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
2309 }
2310
2311 template <typename... Args>
2312 static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2313 Args... args) {
    return {builder.getAffineDimExpr(args)...};
2315 }
2316
2317 LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2318 PatternRewriter &rewriter) const override {
2319 if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2320 !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2321 return rewriter.notifyMatchFailure(rfft2d,
2322 "only supports ranked tensors");
2323 }
2324
2325 auto loc = rfft2d.getLoc();
2326 auto input = rfft2d.getInput();
2327 auto elementType =
2328 cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2329
2330 // Compute the output type and set of dynamic sizes
2331 llvm::SmallVector<Value> dynamicSizes;
2332 auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2333
2334 // Iterator types for the linalg.generic implementation
2335 llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
2336 utils::IteratorType::parallel, utils::IteratorType::parallel,
2337 utils::IteratorType::parallel, utils::IteratorType::reduction,
2338 utils::IteratorType::reduction};
2339
2340 // Inputs/outputs to the linalg.generic implementation
2341 llvm::SmallVector<Value> genericOpInputs = {input};
2342 llvm::SmallVector<Value> genericOpOutputs = {
      createZeroTensor(rewriter, loc, outputType, dynamicSizes),
      createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2345
2346 // Indexing maps for input and output tensors
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());
2352
2353 // Width and height dimensions of the original input.
2354 auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
2355 auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
2356
2357 // Constants and dimension sizes
2358 auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
2359 auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2362
2363 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2364 Value valReal = args[0];
2365 Value sumReal = args[1];
2366 Value sumImag = args[2];
2367
2368 // Indices for angle computation
2369 Value oy = builder.create<linalg::IndexOp>(loc, 1);
2370 Value ox = builder.create<linalg::IndexOp>(loc, 2);
2371 Value iy = builder.create<linalg::IndexOp>(loc, 3);
2372 Value ix = builder.create<linalg::IndexOp>(loc, 4);
2373
2374 // Calculating angle without integer parts of components as sin/cos are
2375 // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
2376 // / W);
2377 auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2378 auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2379
2380 auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2381 auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2382
      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2385
2386 auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2387 auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2388 auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2389 auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2390
2391 // realComponent = valReal * cos(angle)
2392 // imagComponent = valReal * sin(angle)
2393 auto cosAngle = builder.create<math::CosOp>(loc, angle);
2394 auto sinAngle = builder.create<math::SinOp>(loc, angle);
2395 auto realComponent =
2396 builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2397 auto imagComponent =
2398 builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2399
2400 // outReal = sumReal + realComponent
2401 // outImag = sumImag - imagComponent
2402 auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2403 auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);
2404
2405 builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2406 };
2407
2408 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2409 rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2410 indexingMaps, iteratorTypes, buildBody);
2411
2412 return success();
2413 }
2414};
2415
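// Lowers tosa.fft2d much like the rfft2d pattern above, except that it
// consumes both a real and an imaginary input, produces full-width outputs,
// and negates the angle when the inverse attribute is set.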
2416struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2417 using OpRewritePattern::OpRewritePattern;
2418
2419 LogicalResult matchAndRewrite(FFT2dOp fft2d,
2420 PatternRewriter &rewriter) const override {
2421 if (!llvm::all_of(fft2d->getOperandTypes(),
2422 RFFT2dConverter::isRankedTensor) ||
2423 !llvm::all_of(fft2d->getResultTypes(),
2424 RFFT2dConverter::isRankedTensor)) {
2425 return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2426 }
2427
2428 Location loc = fft2d.getLoc();
2429 Value input_real = fft2d.getInputReal();
2430 Value input_imag = fft2d.getInputImag();
2431 BoolAttr inverse = fft2d.getInverseAttr();
2432
2433 auto real_el_ty = cast<FloatType>(
2434 cast<ShapedType>(input_real.getType()).getElementType());
2435 [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2436 cast<ShapedType>(input_imag.getType()).getElementType());
2437
2438 assert(real_el_ty == imag_el_ty);
2439
2440 // Compute the output type and set of dynamic sizes
2441 SmallVector<Value> dynamicSizes;
2442
2443 // Get [N, H, W]
    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2445
2446 SmallVector<int64_t, 3> staticSizes;
2447 dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2448
2449 auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2450
2451 // Iterator types for the linalg.generic implementation
2452 SmallVector<utils::IteratorType, 5> iteratorTypes = {
2453 utils::IteratorType::parallel, utils::IteratorType::parallel,
2454 utils::IteratorType::parallel, utils::IteratorType::reduction,
2455 utils::IteratorType::reduction};
2456
2457 // Inputs/outputs to the linalg.generic implementation
2458 SmallVector<Value> genericOpInputs = {input_real, input_imag};
2459 SmallVector<Value> genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes)};
2464
2465 // Indexing maps for input and output tensors
    auto indexingMaps = AffineMap::inferFromExprList(
        ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());
2472
2473 // Width and height dimensions of the original input.
2474 auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2475 auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2476
2477 // Constants and dimension sizes
2478 auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2479 auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    Value constH =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
    Value constW =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2484
2485 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2486 Value valReal = args[0];
2487 Value valImag = args[1];
2488 Value sumReal = args[2];
2489 Value sumImag = args[3];
2490
2491 // Indices for angle computation
2492 Value oy = builder.create<linalg::IndexOp>(loc, 1);
2493 Value ox = builder.create<linalg::IndexOp>(loc, 2);
2494 Value iy = builder.create<linalg::IndexOp>(loc, 3);
2495 Value ix = builder.create<linalg::IndexOp>(loc, 4);
2496
2497 // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
2498 // ox) % W ) / W);
2499 auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2500 auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2501
2502 auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2503 auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2504
      auto iyRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
      auto ixRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2509
2510 auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2511 auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2512
2513 auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2514 auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2515
2516 if (inverse.getValue()) {
2517 angle = builder.create<arith::MulFOp>(
2518 loc, angle,
2519 rewriter.create<arith::ConstantOp>(
2520 loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
2521 }
2522
2523 // realComponent = val_real * cos(a) + val_imag * sin(a);
2524 // imagComponent = -val_real * sin(a) + val_imag * cos(a);
2525 auto cosAngle = builder.create<math::CosOp>(loc, angle);
2526 auto sinAngle = builder.create<math::SinOp>(loc, angle);
2527
2528 auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2529 auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
2530 auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
2531
2532 auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
2533 auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2534
2535 auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
2536
      // outReal = sumReal + realComponent
      // outImag = sumImag + imagComponent
2539 auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2540 auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
2541
2542 builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2543 };
2544
2545 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2546 fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2547 indexingMaps, iteratorTypes, buildBody);
2548
2549 return success();
2550 }
2551};
2552
2553} // namespace
2554
2555void mlir::tosa::populateTosaToLinalgConversionPatterns(
2556 RewritePatternSet *patterns) {
2557
  // We have multiple resize converters to handle degenerate cases.
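  // Higher-benefit patterns are attempted first, so the specialized
  // broadcast-materialization and 1x1 resize lowerings run before the
  // catch-all GenericResizeConverter.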
  patterns->add<GenericResizeConverter>(patterns->getContext(),
                                        /*benefit=*/100);
  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
                                      /*benefit=*/200);
  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
                                            /*benefit=*/300);
2565
2566 patterns->add<
2567 // clang-format off
2568 PointwiseConverter<tosa::AddOp>,
2569 PointwiseConverter<tosa::SubOp>,
2570 PointwiseConverter<tosa::MulOp>,
2571 PointwiseConverter<tosa::DivOp>,
2572 PointwiseConverter<tosa::NegateOp>,
2573 PointwiseConverter<tosa::PowOp>,
2574 PointwiseConverter<tosa::ReciprocalOp>,
2575 PointwiseConverter<tosa::RsqrtOp>,
2576 PointwiseConverter<tosa::LogOp>,
2577 PointwiseConverter<tosa::ExpOp>,
2578 PointwiseConverter<tosa::AbsOp>,
2579 PointwiseConverter<tosa::TanhOp>,
2580 PointwiseConverter<tosa::ErfOp>,
2581 PointwiseConverter<tosa::BitwiseAndOp>,
2582 PointwiseConverter<tosa::BitwiseOrOp>,
2583 PointwiseConverter<tosa::BitwiseNotOp>,
2584 PointwiseConverter<tosa::BitwiseXorOp>,
2585 PointwiseConverter<tosa::LogicalAndOp>,
2586 PointwiseConverter<tosa::LogicalNotOp>,
2587 PointwiseConverter<tosa::LogicalOrOp>,
2588 PointwiseConverter<tosa::LogicalXorOp>,
2589 PointwiseConverter<tosa::CastOp>,
2590 PointwiseConverter<tosa::LogicalLeftShiftOp>,
2591 PointwiseConverter<tosa::LogicalRightShiftOp>,
2592 PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2593 PointwiseConverter<tosa::ClzOp>,
2594 PointwiseConverter<tosa::SelectOp>,
2595 PointwiseConverter<tosa::GreaterOp>,
2596 PointwiseConverter<tosa::GreaterEqualOp>,
2597 PointwiseConverter<tosa::EqualOp>,
2598 PointwiseConverter<tosa::MaximumOp>,
2599 PointwiseConverter<tosa::MinimumOp>,
2600 PointwiseConverter<tosa::CeilOp>,
2601 PointwiseConverter<tosa::FloorOp>,
2602 PointwiseConverter<tosa::ClampOp>,
2603 PointwiseConverter<tosa::SigmoidOp>,
2604 IdentityNConverter<tosa::IdentityOp>,
2605 ReduceConverter<tosa::ReduceAllOp>,
2606 ReduceConverter<tosa::ReduceAnyOp>,
2607 ReduceConverter<tosa::ReduceMinOp>,
2608 ReduceConverter<tosa::ReduceMaxOp>,
2609 ReduceConverter<tosa::ReduceSumOp>,
2610 ReduceConverter<tosa::ReduceProdOp>,
2611 ArgMaxConverter,
2612 GatherConverter,
2613 RescaleConverter,
2614 ReverseConverter,
2615 RFFT2dConverter,
2616 FFT2dConverter,
2617 TableConverter,
2618 TileConverter>(patterns->getContext());
2619 // clang-format on
2620}
2621
