//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//
12
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"

#include <algorithm>
#include <cstdlib>
#include <limits>
#include <type_traits>
34
35using namespace mlir;
36using namespace mlir::tosa;
37
// Helper function to materialize the semantically correct compare and select
// operations given a binary operation with a specific NaN propagation mode.
//
// In the case of "PROPAGATE" semantics no compare and selection is required and
// this function does nothing.
//
// In the case of "IGNORE" semantics this function materializes a comparison of
// the current operands to the op which will return true for any NaN
// argument and then selects between the non-NaN operation argument and the
// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
// code:
//
// binary<op>(lhs, rhs):
//   result = op(lhs, rhs)
//   if lhs == NaN return rhs
//   if rhs == NaN return lhs
//   return result
//
// In the case that the op is operating on non floating point types we ignore
// the attribute completely, this is consistent with the TOSA spec which has
// the following wording: "This attribute is ignored by non floating-point
// types."
60template <typename OpTy>
61static Value
62materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
63 Value lhs, Value rhs, Value result) {
64 // NaN propagation has no meaning for non floating point types.
65 if (!isa<FloatType>(Val: getElementTypeOrSelf(val: lhs)))
66 return result;
67
68 auto nanMode = op.getNanMode();
69 if (nanMode == "PROPAGATE")
70 return result;
71
72 // Unordered comparison of NaN against itself will always return true.
73 Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
74 op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
75 Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
76 op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
77 Value rhsOrResult =
78 rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
79 return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
80 rhsOrResult);
81}
82
83static Value createLinalgBodyCalculationForElementwiseOp(
84 Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
85 ConversionPatternRewriter &rewriter) {
86 Location loc = op->getLoc();
87 auto elementTy =
88 cast<ShapedType>(Val: op->getOperand(idx: 0).getType()).getElementType();
89
90 // tosa::AbsOp
91 if (isa<tosa::AbsOp>(Val: op) && isa<FloatType>(Val: elementTy))
92 return rewriter.create<math::AbsFOp>(location: loc, args&: resultTypes, args);
93
94 if (isa<tosa::AbsOp>(Val: op) && isa<IntegerType>(Val: elementTy)) {
95 auto zero = rewriter.create<arith::ConstantOp>(
96 location: loc, args: rewriter.getZeroAttr(type: elementTy));
97 auto neg = rewriter.create<arith::SubIOp>(location: loc, args&: zero, args: args[0]);
98 return rewriter.create<arith::MaxSIOp>(location: loc, args: args[0], args&: neg);
99 }
100
101 // tosa::AddOp
102 if (isa<tosa::AddOp>(Val: op) && isa<FloatType>(Val: elementTy))
103 return rewriter.create<arith::AddFOp>(location: loc, args&: resultTypes, args);
104
105 if (isa<tosa::AddOp>(Val: op) && isa<IntegerType>(Val: elementTy))
106 return rewriter.create<arith::AddIOp>(location: loc, args&: resultTypes, args);
107
108 // tosa::SubOp
109 if (isa<tosa::SubOp>(Val: op) && isa<FloatType>(Val: elementTy))
110 return rewriter.create<arith::SubFOp>(location: loc, args&: resultTypes, args);
111
112 if (isa<tosa::SubOp>(Val: op) && isa<IntegerType>(Val: elementTy))
113 return rewriter.create<arith::SubIOp>(location: loc, args&: resultTypes, args);
114
115 // tosa::IntDivOp
116 if (isa<tosa::IntDivOp>(Val: op) && isa<IntegerType>(Val: elementTy))
117 return rewriter.create<arith::DivSIOp>(location: loc, args&: resultTypes, args);
118
119 // tosa::ReciprocalOp
120 if (isa<tosa::ReciprocalOp>(Val: op) && isa<FloatType>(Val: elementTy)) {
121 auto one =
122 rewriter.create<arith::ConstantOp>(location: loc, args: FloatAttr::get(type: elementTy, value: 1));
123 return rewriter.create<arith::DivFOp>(location: loc, args&: resultTypes, args&: one, args: args[0]);
124 }
125
126 // tosa::MulOp
127 if (isa<tosa::MulOp>(Val: op)) {
128 auto shiftVal = cast<tosa::MulOp>(Val: op).getShift();
129 DenseElementsAttr shiftElem;
130 if (!matchPattern(value: shiftVal, pattern: m_Constant(bind_value: &shiftElem))) {
131 (void)rewriter.notifyMatchFailure(arg&: op, msg: "shift value of mul not found");
132 return nullptr;
133 }
134
135 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
136
137 if (isa<FloatType>(Val: elementTy)) {
138 if (shift != 0) {
139 (void)rewriter.notifyMatchFailure(arg&: op,
140 msg: "Cannot have shift value for float");
141 return nullptr;
142 }
143 return rewriter.create<arith::MulFOp>(location: loc, args&: resultTypes, args: args[0], args: args[1]);
144 }
145
146 if (isa<IntegerType>(Val: elementTy)) {
147 Value a = args[0];
148 Value b = args[1];
149
150 if (shift > 0) {
151 auto shiftConst =
152 rewriter.create<arith::ConstantIntOp>(location: loc, args&: shift, /*bitwidth=*/args: 8);
153 if (!a.getType().isInteger(width: 32))
154 a = rewriter.create<arith::ExtSIOp>(location: loc, args: rewriter.getI32Type(), args&: a);
155
156 if (!b.getType().isInteger(width: 32))
157 b = rewriter.create<arith::ExtSIOp>(location: loc, args: rewriter.getI32Type(), args&: b);
158
159 auto result = rewriter.create<tosa::ApplyScaleOp>(
160 location: loc, args: rewriter.getI32Type(), args&: a, args&: b, args&: shiftConst,
161 args: rewriter.getStringAttr(bytes: "SINGLE_ROUND"));
162
163 if (elementTy.isInteger(width: 32))
164 return result;
165
166 return rewriter.create<arith::TruncIOp>(location: loc, args&: elementTy, args&: result);
167 }
168
169 int aWidth = a.getType().getIntOrFloatBitWidth();
170 int bWidth = b.getType().getIntOrFloatBitWidth();
171 int cWidth = resultTypes[0].getIntOrFloatBitWidth();
172
173 if (aWidth < cWidth)
174 a = rewriter.create<arith::ExtSIOp>(location: loc, args: resultTypes[0], args&: a);
175 if (bWidth < cWidth)
176 b = rewriter.create<arith::ExtSIOp>(location: loc, args: resultTypes[0], args&: b);
177
178 return rewriter.create<arith::MulIOp>(location: loc, args&: resultTypes, args&: a, args&: b);
179 }
180 }
181
182 // tosa::NegateOp
183 if (isa<tosa::NegateOp>(Val: op)) {
184 auto negate = cast<tosa::NegateOp>(Val: op);
185
186 FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
187 if (failed(Result: maybeInZp)) {
188 (void)rewriter.notifyMatchFailure(
189 arg&: op, msg: "input1 zero point cannot be statically determined");
190 return nullptr;
191 }
192
193 FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
194 if (failed(Result: maybeOutZp)) {
195 (void)rewriter.notifyMatchFailure(
196 arg&: op, msg: "output zero point cannot be statically determined");
197 return nullptr;
198 }
199
200 int64_t inZp = *maybeInZp;
201 int64_t outZp = *maybeOutZp;
202
203 if (isa<FloatType>(Val: elementTy))
204 return rewriter.create<arith::NegFOp>(location: loc, args&: resultTypes, args: args[0]);
205
206 if (isa<IntegerType>(Val: elementTy)) {
207 if (!inZp && !outZp) {
208 auto constant = rewriter.create<arith::ConstantOp>(
209 location: loc, args: IntegerAttr::get(type: elementTy, value: 0));
210 return rewriter.create<arith::SubIOp>(location: loc, args&: resultTypes, args&: constant,
211 args: args[0]);
212 }
213
214 // Compute the maximum value that can occur in the intermediate buffer.
215 const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
216 const int64_t zpAdd = inZp + outZp;
217 const int64_t maxValue =
218 APInt::getSignedMaxValue(numBits: inputBitWidth).getSExtValue() +
219 std::abs(i: zpAdd) + 1;
220
221 // Convert that maximum value into the maximum bitwidth needed to
222 // represent it. We assume 48-bit numbers may be supported further in
223 // the pipeline.
224 int intermediateBitWidth = 64;
225 if (maxValue <= APInt::getSignedMaxValue(numBits: 16).getSExtValue()) {
226 intermediateBitWidth = 16;
227 } else if (maxValue <= APInt::getSignedMaxValue(numBits: 32).getSExtValue()) {
228 intermediateBitWidth = 32;
229 } else if (maxValue <= APInt::getSignedMaxValue(numBits: 48).getSExtValue()) {
230 intermediateBitWidth = 48;
231 }
232
233 Type intermediateType = rewriter.getIntegerType(width: intermediateBitWidth);
234 Value zpAddValue = rewriter.create<arith::ConstantOp>(
235 location: loc, args: rewriter.getIntegerAttr(type: intermediateType, value: zpAdd));
236
237 // The negation can be applied by doing:
238 // outputValue = inZp + outZp - inputValue
239 auto ext =
240 rewriter.create<arith::ExtSIOp>(location: loc, args&: intermediateType, args: args[0]);
241 auto sub = rewriter.create<arith::SubIOp>(location: loc, args&: zpAddValue, args&: ext);
242
243 // Clamp to the negation range.
244 Value min = rewriter.create<arith::ConstantIntOp>(
245 location: loc, args&: intermediateType,
246 args: APInt::getSignedMinValue(numBits: inputBitWidth).getSExtValue());
247 Value max = rewriter.create<arith::ConstantIntOp>(
248 location: loc, args&: intermediateType,
249 args: APInt::getSignedMaxValue(numBits: inputBitWidth).getSExtValue());
250 auto clamp = clampIntHelper(loc, arg: sub, min, max, rewriter, isUnsigned: false);
251
252 // Truncate to the final value.
253 return rewriter.create<arith::TruncIOp>(location: loc, args&: elementTy, args&: clamp);
254 }
255 }
256
257 // tosa::BitwiseAndOp
258 if (isa<tosa::BitwiseAndOp>(Val: op) && isa<IntegerType>(Val: elementTy))
259 return rewriter.create<arith::AndIOp>(location: loc, args&: resultTypes, args);
260
261 // tosa::BitwiseOrOp
262 if (isa<tosa::BitwiseOrOp>(Val: op) && isa<IntegerType>(Val: elementTy))
263 return rewriter.create<arith::OrIOp>(location: loc, args&: resultTypes, args);
264
265 // tosa::BitwiseNotOp
266 if (isa<tosa::BitwiseNotOp>(Val: op) && isa<IntegerType>(Val: elementTy)) {
267 auto allOnesAttr = rewriter.getIntegerAttr(
268 type: elementTy, value: APInt::getAllOnes(numBits: elementTy.getIntOrFloatBitWidth()));
269 auto allOnes = rewriter.create<arith::ConstantOp>(location: loc, args&: allOnesAttr);
270 return rewriter.create<arith::XOrIOp>(location: loc, args&: resultTypes, args: args[0], args&: allOnes);
271 }
272
273 // tosa::BitwiseXOrOp
274 if (isa<tosa::BitwiseXorOp>(Val: op) && isa<IntegerType>(Val: elementTy))
275 return rewriter.create<arith::XOrIOp>(location: loc, args&: resultTypes, args);
276
277 // tosa::LogicalLeftShiftOp
278 if (isa<tosa::LogicalLeftShiftOp>(Val: op) && isa<IntegerType>(Val: elementTy))
279 return rewriter.create<arith::ShLIOp>(location: loc, args&: resultTypes, args);
280
281 // tosa::LogicalRightShiftOp
282 if (isa<tosa::LogicalRightShiftOp>(Val: op) && isa<IntegerType>(Val: elementTy))
283 return rewriter.create<arith::ShRUIOp>(location: loc, args&: resultTypes, args);
284
285 // tosa::ArithmeticRightShiftOp
286 if (isa<tosa::ArithmeticRightShiftOp>(Val: op) && isa<IntegerType>(Val: elementTy)) {
287 auto result = rewriter.create<arith::ShRSIOp>(location: loc, args&: resultTypes, args);
288 auto round = cast<BoolAttr>(Val: op->getAttr(name: "round")).getValue();
289 if (!round) {
290 return result;
291 }
292
293 Type i1Ty = IntegerType::get(context: rewriter.getContext(), /*width=*/1);
294 auto one =
295 rewriter.create<arith::ConstantOp>(location: loc, args: IntegerAttr::get(type: elementTy, value: 1));
296 auto zero =
297 rewriter.create<arith::ConstantOp>(location: loc, args: IntegerAttr::get(type: elementTy, value: 0));
298 auto i1one =
299 rewriter.create<arith::ConstantOp>(location: loc, args: IntegerAttr::get(type: i1Ty, value: 1));
300
301 // Checking that input2 != 0
302 auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
303 location: loc, args: arith::CmpIPredicate::sgt, args: args[1], args&: zero);
304
305 // Checking for the last bit of input1 to be 1
306 auto subtract =
307 rewriter.create<arith::SubIOp>(location: loc, args&: resultTypes, args: args[1], args&: one);
308 auto shifted =
309 rewriter.create<arith::ShRSIOp>(location: loc, args&: resultTypes, args: args[0], args&: subtract)
310 ->getResults();
311 auto truncated = rewriter.create<arith::TruncIOp>(
312 location: loc, args&: i1Ty, args&: shifted, args: ArrayRef<NamedAttribute>());
313 auto isInputOdd =
314 rewriter.create<arith::AndIOp>(location: loc, args&: i1Ty, args&: truncated, args&: i1one);
315
316 auto shouldRound = rewriter.create<arith::AndIOp>(
317 location: loc, args&: i1Ty, args&: shiftValueGreaterThanZero, args&: isInputOdd);
318 auto extended =
319 rewriter.create<arith::ExtUIOp>(location: loc, args&: resultTypes, args&: shouldRound);
320 return rewriter.create<arith::AddIOp>(location: loc, args&: resultTypes, args&: result, args&: extended);
321 }
322
323 // tosa::ClzOp
324 if (isa<tosa::ClzOp>(Val: op) && isa<IntegerType>(Val: elementTy)) {
325 return rewriter.create<math::CountLeadingZerosOp>(location: loc, args&: elementTy, args: args[0]);
326 }
327
328 // tosa::LogicalAnd
329 if (isa<tosa::LogicalAndOp>(Val: op) && elementTy.isInteger(width: 1))
330 return rewriter.create<arith::AndIOp>(location: loc, args&: resultTypes, args);
331
332 // tosa::LogicalNot
333 if (isa<tosa::LogicalNotOp>(Val: op) && elementTy.isInteger(width: 1)) {
334 auto one = rewriter.create<arith::ConstantOp>(
335 location: loc, args: rewriter.getIntegerAttr(type: elementTy, value: 1));
336 return rewriter.create<arith::XOrIOp>(location: loc, args&: resultTypes, args: args[0], args&: one);
337 }
338
339 // tosa::LogicalOr
340 if (isa<tosa::LogicalOrOp>(Val: op) && elementTy.isInteger(width: 1))
341 return rewriter.create<arith::OrIOp>(location: loc, args&: resultTypes, args);
342
343 // tosa::LogicalXor
344 if (isa<tosa::LogicalXorOp>(Val: op) && elementTy.isInteger(width: 1))
345 return rewriter.create<arith::XOrIOp>(location: loc, args&: resultTypes, args);
346
347 // tosa::PowOp
348 if (isa<tosa::PowOp>(Val: op) && isa<FloatType>(Val: elementTy))
349 return rewriter.create<mlir::math::PowFOp>(location: loc, args&: resultTypes, args);
350
351 // tosa::RsqrtOp
352 if (isa<tosa::RsqrtOp>(Val: op) && isa<FloatType>(Val: elementTy))
353 return rewriter.create<mlir::math::RsqrtOp>(location: loc, args&: resultTypes, args);
354
355 // tosa::LogOp
356 if (isa<tosa::LogOp>(Val: op) && isa<FloatType>(Val: elementTy))
357 return rewriter.create<mlir::math::LogOp>(location: loc, args&: resultTypes, args);
358
359 // tosa::ExpOp
360 if (isa<tosa::ExpOp>(Val: op) && isa<FloatType>(Val: elementTy))
361 return rewriter.create<mlir::math::ExpOp>(location: loc, args&: resultTypes, args);
362
363 // tosa::SinOp
364 if (isa<tosa::SinOp>(Val: op) && isa<FloatType>(Val: elementTy))
365 return rewriter.create<mlir::math::SinOp>(location: loc, args&: resultTypes, args);
366
367 // tosa::CosOp
368 if (isa<tosa::CosOp>(Val: op) && isa<FloatType>(Val: elementTy))
369 return rewriter.create<mlir::math::CosOp>(location: loc, args&: resultTypes, args);
370
371 // tosa::TanhOp
372 if (isa<tosa::TanhOp>(Val: op) && isa<FloatType>(Val: elementTy))
373 return rewriter.create<mlir::math::TanhOp>(location: loc, args&: resultTypes, args);
374
375 // tosa::ErfOp
376 if (isa<tosa::ErfOp>(Val: op) && llvm::isa<FloatType>(Val: elementTy))
377 return rewriter.create<mlir::math::ErfOp>(location: loc, args&: resultTypes, args);
378
379 // tosa::GreaterOp
380 if (isa<tosa::GreaterOp>(Val: op) && isa<FloatType>(Val: elementTy))
381 return rewriter.create<arith::CmpFOp>(location: loc, args: arith::CmpFPredicate::OGT,
382 args: args[0], args: args[1]);
383
384 if (isa<tosa::GreaterOp>(Val: op) && elementTy.isSignlessInteger())
385 return rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::sgt,
386 args: args[0], args: args[1]);
387
388 // tosa::GreaterEqualOp
389 if (isa<tosa::GreaterEqualOp>(Val: op) && isa<FloatType>(Val: elementTy))
390 return rewriter.create<arith::CmpFOp>(location: loc, args: arith::CmpFPredicate::OGE,
391 args: args[0], args: args[1]);
392
393 if (isa<tosa::GreaterEqualOp>(Val: op) && elementTy.isSignlessInteger())
394 return rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::sge,
395 args: args[0], args: args[1]);
396
397 // tosa::EqualOp
398 if (isa<tosa::EqualOp>(Val: op) && isa<FloatType>(Val: elementTy))
399 return rewriter.create<arith::CmpFOp>(location: loc, args: arith::CmpFPredicate::OEQ,
400 args: args[0], args: args[1]);
401
402 if (isa<tosa::EqualOp>(Val: op) && elementTy.isSignlessInteger())
403 return rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::eq,
404 args: args[0], args: args[1]);
405
406 // tosa::SelectOp
407 if (isa<tosa::SelectOp>(Val: op)) {
408 elementTy = cast<ShapedType>(Val: op->getOperand(idx: 1).getType()).getElementType();
409 if (isa<FloatType>(Val: elementTy) || isa<IntegerType>(Val: elementTy))
410 return rewriter.create<arith::SelectOp>(location: loc, args: args[0], args: args[1], args: args[2]);
411 }
412
413 // tosa::MaximumOp
414 if (isa<tosa::MaximumOp>(Val: op) && isa<FloatType>(Val: elementTy)) {
415 auto max = rewriter.create<arith::MaximumFOp>(location: loc, args: args[0], args: args[1]);
416 return materializeBinaryNanCheckIfRequired(op: llvm::cast<tosa::MaximumOp>(Val: op),
417 rewriter, lhs: args[0], rhs: args[1], result: max);
418 }
419
420 if (isa<tosa::MaximumOp>(Val: op) && elementTy.isSignlessInteger()) {
421 return rewriter.create<arith::MaxSIOp>(location: loc, args: args[0], args: args[1]);
422 }
423
424 // tosa::MinimumOp
425 if (isa<tosa::MinimumOp>(Val: op) && isa<FloatType>(Val: elementTy)) {
426 auto min = rewriter.create<arith::MinimumFOp>(location: loc, args: args[0], args: args[1]);
427 return materializeBinaryNanCheckIfRequired(op: llvm::cast<tosa::MinimumOp>(Val: op),
428 rewriter, lhs: args[0], rhs: args[1], result: min);
429 }
430
431 if (isa<tosa::MinimumOp>(Val: op) && elementTy.isSignlessInteger()) {
432 return rewriter.create<arith::MinSIOp>(location: loc, args: args[0], args: args[1]);
433 }
434
435 // tosa::CeilOp
436 if (isa<tosa::CeilOp>(Val: op) && isa<FloatType>(Val: elementTy))
437 return rewriter.create<math::CeilOp>(location: loc, args&: resultTypes, args);
438
439 // tosa::FloorOp
440 if (isa<tosa::FloorOp>(Val: op) && isa<FloatType>(Val: elementTy))
441 return rewriter.create<math::FloorOp>(location: loc, args&: resultTypes, args);
442
443 // tosa::ClampOp
444 if (isa<tosa::ClampOp>(Val: op) && isa<FloatType>(Val: elementTy)) {
445 bool losesInfo = false;
446 APFloat minApf = cast<FloatAttr>(Val: op->getAttr(name: "min_val")).getValue();
447 APFloat maxApf = cast<FloatAttr>(Val: op->getAttr(name: "max_val")).getValue();
448 minApf.convert(ToSemantics: cast<FloatType>(Val&: elementTy).getFloatSemantics(),
449 RM: APFloat::rmNearestTiesToEven, losesInfo: &losesInfo);
450 maxApf.convert(ToSemantics: cast<FloatType>(Val&: elementTy).getFloatSemantics(),
451 RM: APFloat::rmNearestTiesToEven, losesInfo: &losesInfo);
452 auto min = rewriter.create<arith::ConstantOp>(
453 location: loc, args&: elementTy, args: rewriter.getFloatAttr(type: elementTy, value: minApf));
454 auto max = rewriter.create<arith::ConstantOp>(
455 location: loc, args&: elementTy, args: rewriter.getFloatAttr(type: elementTy, value: maxApf));
456 auto result = clampFloatHelper(loc, arg: args[0], min, max, rewriter);
457
458 auto clampOp = llvm::cast<tosa::ClampOp>(Val: op);
459 const auto nanMode = clampOp.getNanMode();
460
461 // NaN propagation has no meaning for non floating point types.
462 if (!isa<FloatType>(Val: elementTy))
463 return result;
464
465 // In the case of "PROPAGATE" semantics no compare and selection is
466 // required.
467 if (nanMode == "PROPAGATE")
468 return result;
469
470 // In the case of "IGNORE" semantics materialize a comparison
471 // of the current operand to the reduction which will return true for a NaN
472 // argument and then selects between the initial reduction value and the
473 // calculated result based on whether the argument is NaN or not. In pseudo
474 // code:
475 //
476 // reduce<op>(x, init):
477 // result = op(init, x)
478 // return init if x == NaN else result
479
480 // Unordered comparison of NaN against itself will always return true.
481 Value isNaN = rewriter.create<arith::CmpFOp>(
482 location: op->getLoc(), args: arith::CmpFPredicate::UNO, args: args[0], args: args[0]);
483 // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
484 // is NaN.
485 return rewriter.create<arith::SelectOp>(location: op->getLoc(), args&: isNaN, args&: min, args&: result);
486 }
487
488 if (isa<tosa::ClampOp>(Val: op) && isa<IntegerType>(Val: elementTy)) {
489 auto intTy = cast<IntegerType>(Val&: elementTy);
490 int64_t min =
491 cast<IntegerAttr>(Val: op->getAttr(name: "min_val")).getValue().getSExtValue();
492 int64_t max =
493 cast<IntegerAttr>(Val: op->getAttr(name: "max_val")).getValue().getSExtValue();
494
495 int64_t minRepresentable = std::numeric_limits<int64_t>::min();
496 int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
497 if (intTy.isUnsignedInteger()) {
498 minRepresentable = 0;
499 if (intTy.getIntOrFloatBitWidth() <= 63) {
500 maxRepresentable =
501 (int64_t)APInt::getMaxValue(numBits: intTy.getIntOrFloatBitWidth())
502 .getZExtValue();
503 }
504 } else if (intTy.getIntOrFloatBitWidth() <= 64) {
505 // Ensure that min & max fit into signed n-bit constants.
506 minRepresentable = APInt::getSignedMinValue(numBits: intTy.getIntOrFloatBitWidth())
507 .getSExtValue();
508 maxRepresentable = APInt::getSignedMaxValue(numBits: intTy.getIntOrFloatBitWidth())
509 .getSExtValue();
510 }
511 // Ensure that the bounds are representable as n-bit signed/unsigned
512 // integers.
513 min = std::max(a: min, b: minRepresentable);
514 max = std::max(a: max, b: minRepresentable);
515 min = std::min(a: min, b: maxRepresentable);
516 max = std::min(a: max, b: maxRepresentable);
517
518 auto minVal = rewriter.create<arith::ConstantIntOp>(
519 location: loc, args&: min, args: intTy.getIntOrFloatBitWidth());
520 auto maxVal = rewriter.create<arith::ConstantIntOp>(
521 location: loc, args&: max, args: intTy.getIntOrFloatBitWidth());
522 return clampIntHelper(loc, arg: args[0], min: minVal, max: maxVal, rewriter,
523 isUnsigned: intTy.isUnsignedInteger());
524 }
525
526 // tosa::SigmoidOp
527 if (isa<tosa::SigmoidOp>(Val: op) && isa<FloatType>(Val: elementTy)) {
528 auto one =
529 rewriter.create<arith::ConstantOp>(location: loc, args: FloatAttr::get(type: elementTy, value: 1));
530 auto negate = rewriter.create<arith::NegFOp>(location: loc, args&: resultTypes, args: args[0]);
531 auto exp = rewriter.create<mlir::math::ExpOp>(location: loc, args&: resultTypes, args&: negate);
532 auto added = rewriter.create<arith::AddFOp>(location: loc, args&: resultTypes, args&: exp, args&: one);
533 return rewriter.create<arith::DivFOp>(location: loc, args&: resultTypes, args&: one, args&: added);
534 }
535
536 // tosa::CastOp
537 if (isa<tosa::CastOp>(Val: op)) {
538 Type srcTy = elementTy;
539 Type dstTy = resultTypes.front();
540 if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
541 (void)rewriter.notifyMatchFailure(arg&: op, msg: "unsupported type");
542 return nullptr;
543 }
544
545 bool bitExtend =
546 srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
547
548 if (srcTy == dstTy)
549 return args.front();
550
551 if (isa<FloatType>(Val: srcTy) && isa<FloatType>(Val: dstTy) && bitExtend)
552 return rewriter.create<arith::ExtFOp>(location: loc, args&: resultTypes, args,
553 args: ArrayRef<NamedAttribute>());
554
555 if (isa<FloatType>(Val: srcTy) && isa<FloatType>(Val: dstTy) && !bitExtend)
556 return rewriter.create<arith::TruncFOp>(location: loc, args&: resultTypes, args,
557 args: ArrayRef<NamedAttribute>());
558
559 // 1-bit integers need to be treated as signless.
560 if (srcTy.isInteger(width: 1) && arith::UIToFPOp::areCastCompatible(inputs: srcTy, outputs: dstTy))
561 return rewriter.create<arith::UIToFPOp>(location: loc, args&: resultTypes, args,
562 args: ArrayRef<NamedAttribute>());
563
564 if (srcTy.isInteger(width: 1) && isa<IntegerType>(Val: dstTy) && bitExtend)
565 return rewriter.create<arith::ExtUIOp>(location: loc, args&: resultTypes, args,
566 args: ArrayRef<NamedAttribute>());
567
568 // Unsigned integers need an unrealized cast so that they can be passed
569 // to UIToFP.
570 if (srcTy.isUnsignedInteger() && isa<FloatType>(Val: dstTy)) {
571 auto unrealizedCast =
572 rewriter
573 .create<UnrealizedConversionCastOp>(
574 location: loc, args: rewriter.getIntegerType(width: srcTy.getIntOrFloatBitWidth()),
575 args: args[0])
576 .getResult(i: 0);
577 return rewriter.create<arith::UIToFPOp>(location: loc, args: resultTypes[0],
578 args&: unrealizedCast);
579 }
580
581 // All other si-to-fp conversions should be handled by SIToFP.
582 if (arith::SIToFPOp::areCastCompatible(inputs: srcTy, outputs: dstTy))
583 return rewriter.create<arith::SIToFPOp>(location: loc, args&: resultTypes, args,
584 args: ArrayRef<NamedAttribute>());
585
586 // Casting to boolean, floats need to only be checked as not-equal to zero.
587 if (isa<FloatType>(Val: srcTy) && dstTy.isInteger(width: 1)) {
588 Value zero = rewriter.create<arith::ConstantOp>(
589 location: loc, args: rewriter.getFloatAttr(type: srcTy, value: 0.0));
590 return rewriter.create<arith::CmpFOp>(location: loc, args: arith::CmpFPredicate::UNE,
591 args: args.front(), args&: zero);
592 }
593
594 if (arith::FPToSIOp::areCastCompatible(inputs: srcTy, outputs: dstTy)) {
595 auto rounded = rewriter.create<math::RoundEvenOp>(location: loc, args: args[0]);
596
597 const auto &fltSemantics = cast<FloatType>(Val&: srcTy).getFloatSemantics();
598 // Check whether neither int min nor int max can be represented in the
599 // input floating-point type due to too short exponent range.
600 if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
601 APFloat::semanticsMaxExponent(fltSemantics)) {
602 // Use cmp + select to replace infinites by int min / int max. Other
603 // integral values can be represented in the integer space.
604 auto conv = rewriter.create<arith::FPToSIOp>(location: loc, args&: dstTy, args&: rounded);
605 auto posInf = rewriter.create<arith::ConstantOp>(
606 location: loc, args: rewriter.getFloatAttr(type: getElementTypeOrSelf(type: srcTy),
607 value: APFloat::getInf(Sem: fltSemantics)));
608 auto negInf = rewriter.create<arith::ConstantOp>(
609 location: loc, args: rewriter.getFloatAttr(
610 type: getElementTypeOrSelf(type: srcTy),
611 value: APFloat::getInf(Sem: fltSemantics, /*Negative=*/true)));
612 auto overflow = rewriter.create<arith::CmpFOp>(
613 location: loc, args: arith::CmpFPredicate::UEQ, args&: rounded, args&: posInf);
614 auto underflow = rewriter.create<arith::CmpFOp>(
615 location: loc, args: arith::CmpFPredicate::UEQ, args&: rounded, args&: negInf);
616 auto intMin = rewriter.create<arith::ConstantOp>(
617 location: loc, args: rewriter.getIntegerAttr(
618 type: getElementTypeOrSelf(type: dstTy),
619 value: APInt::getSignedMinValue(numBits: dstTy.getIntOrFloatBitWidth())));
620 auto intMax = rewriter.create<arith::ConstantOp>(
621 location: loc, args: rewriter.getIntegerAttr(
622 type: getElementTypeOrSelf(type: dstTy),
623 value: APInt::getSignedMaxValue(numBits: dstTy.getIntOrFloatBitWidth())));
624 auto maxClamped =
625 rewriter.create<arith::SelectOp>(location: loc, args&: overflow, args&: intMax, args&: conv);
626 return rewriter.create<arith::SelectOp>(location: loc, args&: underflow, args&: intMin,
627 args&: maxClamped);
628 }
629
630 auto intMinFP = rewriter.create<arith::ConstantOp>(
631 location: loc, args: rewriter.getFloatAttr(
632 type: getElementTypeOrSelf(type: srcTy),
633 value: APInt::getSignedMinValue(numBits: dstTy.getIntOrFloatBitWidth())
634 .getSExtValue()));
635
636 // Check whether the mantissa has enough bits to represent int max.
637 if (cast<FloatType>(Val&: srcTy).getFPMantissaWidth() >=
638 dstTy.getIntOrFloatBitWidth() - 1) {
639 // Int min can also be represented since it is a power of two and thus
640 // consists of a single leading bit. Therefore we can clamp the input
641 // in the floating-point domain.
642
643 auto intMaxFP = rewriter.create<arith::ConstantOp>(
644 location: loc, args: rewriter.getFloatAttr(
645 type: getElementTypeOrSelf(type: srcTy),
646 value: APInt::getSignedMaxValue(numBits: dstTy.getIntOrFloatBitWidth())
647 .getSExtValue()));
648
649 Value clamped =
650 clampFloatHelper(loc, arg: rounded, min: intMinFP, max: intMaxFP, rewriter);
651 return rewriter.create<arith::FPToSIOp>(location: loc, args&: dstTy, args&: clamped);
652 }
653
654 // Due to earlier check we know exponant range is big enough to represent
655 // int min. We can therefore rely on int max + 1 being representable as
656 // well because it's just int min with a positive sign. So clamp the min
657 // value and compare against that to select the max int value if needed.
658 auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
659 location: loc, args: rewriter.getFloatAttr(
660 type: getElementTypeOrSelf(type: srcTy),
661 value: static_cast<double>(
662 APInt::getSignedMaxValue(numBits: dstTy.getIntOrFloatBitWidth())
663 .getSExtValue()) +
664 1.0f));
665
666 auto intMax = rewriter.create<arith::ConstantOp>(
667 location: loc, args: rewriter.getIntegerAttr(
668 type: getElementTypeOrSelf(type: dstTy),
669 value: APInt::getSignedMaxValue(numBits: dstTy.getIntOrFloatBitWidth())));
670 auto minClampedFP =
671 rewriter.create<arith::MaximumFOp>(location: loc, args&: rounded, args&: intMinFP);
672 auto minClamped =
673 rewriter.create<arith::FPToSIOp>(location: loc, args&: dstTy, args&: minClampedFP);
674 auto overflow = rewriter.create<arith::CmpFOp>(
675 location: loc, args: arith::CmpFPredicate::UGE, args&: rounded, args&: intMaxPlusOneFP);
676 return rewriter.create<arith::SelectOp>(location: loc, args&: overflow, args&: intMax,
677 args&: minClamped);
678 }
679
680 // Casting to boolean, integers need to only be checked as not-equal to
681 // zero.
682 if (isa<IntegerType>(Val: srcTy) && dstTy.isInteger(width: 1)) {
683 Value zero = rewriter.create<arith::ConstantIntOp>(
684 location: loc, args: 0, args: srcTy.getIntOrFloatBitWidth());
685 return rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::ne,
686 args: args.front(), args&: zero);
687 }
688
689 if (isa<IntegerType>(Val: srcTy) && isa<IntegerType>(Val: dstTy) && bitExtend)
690 return rewriter.create<arith::ExtSIOp>(location: loc, args&: resultTypes, args,
691 args: ArrayRef<NamedAttribute>());
692
693 if (isa<IntegerType>(Val: srcTy) && isa<IntegerType>(Val: dstTy) && !bitExtend) {
694 return rewriter.create<arith::TruncIOp>(location: loc, args&: dstTy, args: args[0]);
695 }
696 }
697
698 (void)rewriter.notifyMatchFailure(
699 arg&: op, msg: "unhandled op for linalg body calculation for elementwise op");
700 return nullptr;
701}
702
703using IndexPool = DenseMap<int64_t, Value>;
704
705// Emit an 'arith.constant' op for the given index if it has not been created
706// yet, or return an existing constant. This will prevent an excessive creation
707// of redundant constants, easing readability of emitted code for unit tests.
708static Value createIndex(PatternRewriter &rewriter, Location loc,
709 IndexPool &indexPool, int64_t index) {
710 auto [it, inserted] = indexPool.try_emplace(Key: index);
711 if (inserted)
712 it->second =
713 rewriter.create<arith::ConstantOp>(location: loc, args: rewriter.getIndexAttr(value: index));
714 return it->second;
715}
716
717static Value getTensorDim(PatternRewriter &rewriter, Location loc,
718 IndexPool &indexPool, Value tensor, int64_t index) {
719 auto indexValue = createIndex(rewriter, loc, indexPool, index);
720 return rewriter.create<tensor::DimOp>(location: loc, args&: tensor, args&: indexValue).getResult();
721}
722
723static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
724 IndexPool &indexPool, Value tensor,
725 int64_t index) {
726 auto shapedType = dyn_cast<ShapedType>(Val: tensor.getType());
727 assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
728 assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
729 if (shapedType.isDynamicDim(idx: index))
730 return getTensorDim(rewriter, loc, indexPool, tensor, index);
731 return rewriter.getIndexAttr(value: shapedType.getDimSize(idx: index));
732}
733
734static bool operandsAndResultsRanked(Operation *operation) {
735 auto isRanked = [](Value value) {
736 return isa<RankedTensorType>(Val: value.getType());
737 };
738 return llvm::all_of(Range: operation->getOperands(), P: isRanked) &&
739 llvm::all_of(Range: operation->getResults(), P: isRanked);
740}
741
742// Compute the runtime dimension size for dimension 'dim' of the output by
743// inspecting input 'operands', all of which are expected to have the same rank.
744// This function returns a pair {targetSize, masterOperand}.
745//
746// The runtime size of the output dimension is returned either as a statically
747// computed attribute or as a runtime SSA value.
748//
749// If the target size was inferred directly from one dominating operand, that
750// operand is returned in 'masterOperand'. If the target size is inferred from
751// multiple operands, 'masterOperand' is set to nullptr.
752static std::pair<OpFoldResult, Value>
753computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
754 ValueRange operands, int64_t dim) {
755 // If any input operand contains a static size greater than 1 for this
756 // dimension, that is the target size. An occurrence of an additional static
757 // dimension greater than 1 with a different value is undefined behavior.
758 for (auto operand : operands) {
759 auto size = cast<RankedTensorType>(Val: operand.getType()).getDimSize(idx: dim);
760 if (ShapedType::isStatic(dValue: size) && size > 1)
761 return {rewriter.getIndexAttr(value: size), operand};
762 }
763
764 // Filter operands with dynamic dimension
765 auto operandsWithDynamicDim =
766 llvm::filter_to_vector(C&: operands, Pred: [&](Value operand) {
767 return cast<RankedTensorType>(Val: operand.getType()).isDynamicDim(idx: dim);
768 });
769
770 // If no operand has a dynamic dimension, it means all sizes were 1
771 if (operandsWithDynamicDim.empty())
772 return {rewriter.getIndexAttr(value: 1), operands.front()};
773
774 // Emit code that computes the runtime size for this dimension. If there is
775 // only one operand with a dynamic dimension, it is considered the master
776 // operand that determines the runtime size of the output dimension.
777 auto targetSize =
778 getTensorDim(rewriter, loc, indexPool, tensor: operandsWithDynamicDim[0], index: dim);
779 if (operandsWithDynamicDim.size() == 1)
780 return {targetSize, operandsWithDynamicDim[0]};
781
782 // Calculate maximum size among all dynamic dimensions
783 for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
784 auto nextSize =
785 getTensorDim(rewriter, loc, indexPool, tensor: operandsWithDynamicDim[i], index: dim);
786 targetSize = rewriter.create<arith::MaxUIOp>(location: loc, args&: targetSize, args&: nextSize);
787 }
788 return {targetSize, nullptr};
789}
790
791// Compute the runtime output size for all dimensions. This function returns
792// a pair {targetShape, masterOperands}.
793static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
794computeTargetShape(PatternRewriter &rewriter, Location loc,
795 IndexPool &indexPool, ValueRange operands) {
796 assert(!operands.empty());
797 auto rank = cast<RankedTensorType>(Val: operands.front().getType()).getRank();
798 SmallVector<OpFoldResult> targetShape;
799 SmallVector<Value> masterOperands;
800 for (auto dim : llvm::seq<int64_t>(Begin: 0, End: rank)) {
801 auto [targetSize, masterOperand] =
802 computeTargetSize(rewriter, loc, indexPool, operands, dim);
803 targetShape.push_back(Elt: targetSize);
804 masterOperands.push_back(Elt: masterOperand);
805 }
806 return {targetShape, masterOperands};
807}
808
809static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
810 IndexPool &indexPool, Value operand,
811 int64_t dim, OpFoldResult targetSize,
812 Value masterOperand) {
813 // Nothing to do if this is a static dimension
814 auto rankedTensorType = cast<RankedTensorType>(Val: operand.getType());
815 if (!rankedTensorType.isDynamicDim(idx: dim))
816 return operand;
817
818 // If the target size for this dimension was directly inferred by only taking
819 // this operand into account, there is no need to broadcast. This is an
820 // optimization that will prevent redundant control flow, and constitutes the
821 // main motivation for tracking "master operands".
822 if (operand == masterOperand)
823 return operand;
824
825 // Affine maps for 'linalg.generic' op
826 auto rank = rankedTensorType.getRank();
827 SmallVector<AffineExpr> affineExprs;
828 for (auto index : llvm::seq<int64_t>(Begin: 0, End: rank)) {
829 auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(constant: 0)
830 : rewriter.getAffineDimExpr(position: index);
831 affineExprs.push_back(Elt: affineExpr);
832 }
833 auto broadcastAffineMap =
834 AffineMap::get(dimCount: rank, symbolCount: 0, results: affineExprs, context: rewriter.getContext());
835 auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
836 SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
837
838 // Check if broadcast is necessary
839 auto one = createIndex(rewriter, loc, indexPool, index: 1);
840 auto runtimeSize = getTensorDim(rewriter, loc, indexPool, tensor: operand, index: dim);
841 auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
842 location: loc, args: arith::CmpIPredicate::eq, args&: runtimeSize, args&: one);
843
844 // Emit 'then' region of 'scf.if'
845 auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
846 // It is not safe to cache constants across regions.
847 // New constants could potentially violate dominance requirements.
848 IndexPool localPool;
849
850 // Emit 'tensor.empty' op
851 SmallVector<OpFoldResult> outputTensorShape;
852 for (auto index : llvm::seq<int64_t>(Begin: 0, End: rank)) {
853 auto size = index == dim ? targetSize
854 : getOrFoldTensorDim(rewriter, loc, indexPool&: localPool,
855 tensor: operand, index);
856 outputTensorShape.push_back(Elt: size);
857 }
858 Value outputTensor = opBuilder.create<tensor::EmptyOp>(
859 location: loc, args&: outputTensorShape, args: rankedTensorType.getElementType());
860
861 // Emit 'linalg.generic' op
862 auto resultTensor =
863 opBuilder
864 .create<linalg::GenericOp>(
865 location: loc, args: outputTensor.getType(), args&: operand, args&: outputTensor, args&: affineMaps,
866 args: getNParallelLoopsAttrs(nParallelLoops: rank),
867 args: [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
868 // Emit 'linalg.yield' op
869 opBuilder.create<linalg::YieldOp>(location: loc, args: blockArgs.front());
870 })
871 .getResult(i: 0);
872
873 // Cast to original operand type if necessary
874 auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
875 location: loc, args: operand.getType(), args&: resultTensor);
876
877 // Emit 'scf.yield' op
878 opBuilder.create<scf::YieldOp>(location: loc, args&: castResultTensor);
879 };
880
881 // Emit 'else' region of 'scf.if'
882 auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
883 opBuilder.create<scf::YieldOp>(location: loc, args&: operand);
884 };
885
886 // Emit 'scf.if' op
887 auto ifOp = rewriter.create<scf::IfOp>(location: loc, args&: broadcastNecessary,
888 args&: emitThenRegion, args&: emitElseRegion);
889 return ifOp.getResult(i: 0);
890}
891
892static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
893 IndexPool &indexPool, Value operand,
894 ArrayRef<OpFoldResult> targetShape,
895 ArrayRef<Value> masterOperands) {
896 int64_t rank = cast<RankedTensorType>(Val: operand.getType()).getRank();
897 assert((int64_t)targetShape.size() == rank);
898 assert((int64_t)masterOperands.size() == rank);
899 for (auto index : llvm::seq<int64_t>(Begin: 0, End: rank))
900 operand =
901 broadcastDynamicDimension(rewriter, loc, indexPool, operand, dim: index,
902 targetSize: targetShape[index], masterOperand: masterOperands[index]);
903 return operand;
904}
905
906static SmallVector<Value>
907broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
908 IndexPool &indexPool, ValueRange operands,
909 ArrayRef<OpFoldResult> targetShape,
910 ArrayRef<Value> masterOperands) {
911 // No need to broadcast for unary operations
912 if (operands.size() == 1)
913 return operands;
914
915 // Broadcast dynamic dimensions operand by operand
916 return llvm::map_to_vector(C&: operands, F: [&](Value operand) {
917 return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
918 targetShape, masterOperands);
919 });
920}
921
922static LogicalResult
923emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
924 Operation *operation, ValueRange operands,
925 ArrayRef<OpFoldResult> targetShape,
926 const TypeConverter &converter) {
927 // Generate output tensor
928 auto resultType = cast_or_null<RankedTensorType>(
929 Val: converter.convertType(t: operation->getResultTypes().front()));
930 if (!resultType) {
931 return rewriter.notifyMatchFailure(arg&: operation, msg: "failed to convert type");
932 }
933 Value outputTensor = rewriter.create<tensor::EmptyOp>(
934 location: loc, args&: targetShape, args: resultType.getElementType());
935
936 // Create affine maps. Input affine maps broadcast static dimensions of size
937 // 1. The output affine map is an identity map.
938 //
939 auto rank = resultType.getRank();
940 auto affineMaps = llvm::map_to_vector(C&: operands, F: [&](Value operand) {
941 auto shape = cast<ShapedType>(Val: operand.getType()).getShape();
942 SmallVector<AffineExpr> affineExprs;
943 for (auto it : llvm::enumerate(First&: shape)) {
944 // Prefer producting identity maps whenever possible (i.e. no broadcasting
945 // needed) because some transforms (like reshape folding)
946 // do not support affine constant exprs.
947 bool requiresBroadcast =
948 (it.value() == 1 && resultType.getDimSize(idx: it.index()) != 1);
949 auto affineExpr = requiresBroadcast
950 ? rewriter.getAffineConstantExpr(constant: 0)
951 : rewriter.getAffineDimExpr(position: it.index());
952 affineExprs.push_back(Elt: affineExpr);
953 }
954 return AffineMap::get(dimCount: rank, symbolCount: 0, results: affineExprs, context: rewriter.getContext());
955 });
956 affineMaps.push_back(Elt: rewriter.getMultiDimIdentityMap(rank));
957
958 // Emit 'linalg.generic' op
959 bool encounteredError = false;
960 auto linalgOp = rewriter.create<linalg::GenericOp>(
961 location: loc, args: outputTensor.getType(), args&: operands, args&: outputTensor, args&: affineMaps,
962 args: getNParallelLoopsAttrs(nParallelLoops: rank),
963 args: [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
964 Value opResult = createLinalgBodyCalculationForElementwiseOp(
965 op: operation, args: blockArgs.take_front(n: operation->getNumOperands()),
966 resultTypes: {resultType.getElementType()}, rewriter);
967 if (!opResult) {
968 encounteredError = true;
969 return;
970 }
971 opBuilder.create<linalg::YieldOp>(location: loc, args&: opResult);
972 });
973 if (encounteredError)
974 return rewriter.notifyMatchFailure(
975 arg&: operation, msg: "unable to create linalg.generic body for elementwise op");
976
977 // Cast 'linalg.generic' result into original result type if needed
978 auto castResult = rewriter.createOrFold<tensor::CastOp>(
979 location: loc, args&: resultType, args: linalgOp->getResult(idx: 0));
980 rewriter.replaceOp(op: operation, newValues: castResult);
981 return success();
982}
983
984static ValueRange getBroadcastableOperands(Operation *operation,
985 ValueRange operands) {
986 // Shift cannot broadcast
987 if (isa<tosa::MulOp>(Val: operation))
988 return operands.take_front(n: 2);
989 // Input1_zp and output_zp cannot broadcast
990 if (isa<tosa::NegateOp>(Val: operation))
991 return operands.take_front(n: 1);
992 return operands;
993}
994
995static LogicalResult
996elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
997 ConversionPatternRewriter &rewriter,
998 const TypeConverter &converter) {
999
1000 // Collect op properties
1001 assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
1002 assert(operation->getNumOperands() >= 1 &&
1003 "elementwise op expects at least 1 operand");
1004 if (!operandsAndResultsRanked(operation))
1005 return rewriter.notifyMatchFailure(arg&: operation,
1006 msg: "Unranked tensors not supported");
1007
1008 // Lower operation
1009 IndexPool indexPool;
1010 auto loc = operation->getLoc();
1011 auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
1012 auto [targetShape, masterOperands] =
1013 computeTargetShape(rewriter, loc, indexPool, operands: operandsToBroadcast);
1014 auto broadcastOperands =
1015 broadcastDynamicDimensions(rewriter, loc, indexPool, operands: operandsToBroadcast,
1016 targetShape, masterOperands);
1017 return emitElementwiseComputation(rewriter, loc, operation, operands: broadcastOperands,
1018 targetShape, converter);
1019}
1020
1021// Returns the constant initial value for a given reduction operation. The
1022// attribute type varies depending on the element type required.
1023static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
1024 PatternRewriter &rewriter) {
1025 if (isa<tosa::ReduceSumOp>(Val: op) && isa<FloatType>(Val: elementTy))
1026 return rewriter.getFloatAttr(type: elementTy, value: 0.0);
1027
1028 if (isa<tosa::ReduceSumOp>(Val: op) && isa<IntegerType>(Val: elementTy))
1029 return rewriter.getIntegerAttr(type: elementTy, value: 0);
1030
1031 if (isa<tosa::ReduceProductOp>(Val: op) && isa<FloatType>(Val: elementTy))
1032 return rewriter.getFloatAttr(type: elementTy, value: 1.0);
1033
1034 if (isa<tosa::ReduceProductOp>(Val: op) && isa<IntegerType>(Val: elementTy))
1035 return rewriter.getIntegerAttr(type: elementTy, value: 1);
1036
1037 if (isa<tosa::ReduceMinOp>(Val: op) && isa<FloatType>(Val: elementTy))
1038 return rewriter.getFloatAttr(
1039 type: elementTy, value: APFloat::getLargest(
1040 Sem: cast<FloatType>(Val&: elementTy).getFloatSemantics(), Negative: false));
1041
1042 if (isa<tosa::ReduceMinOp>(Val: op) && isa<IntegerType>(Val: elementTy))
1043 return rewriter.getIntegerAttr(
1044 type: elementTy, value: APInt::getSignedMaxValue(numBits: elementTy.getIntOrFloatBitWidth()));
1045
1046 if (isa<tosa::ReduceMaxOp>(Val: op) && isa<FloatType>(Val: elementTy))
1047 return rewriter.getFloatAttr(
1048 type: elementTy, value: APFloat::getLargest(
1049 Sem: cast<FloatType>(Val&: elementTy).getFloatSemantics(), Negative: true));
1050
1051 if (isa<tosa::ReduceMaxOp>(Val: op) && isa<IntegerType>(Val: elementTy))
1052 return rewriter.getIntegerAttr(
1053 type: elementTy, value: APInt::getSignedMinValue(numBits: elementTy.getIntOrFloatBitWidth()));
1054
1055 if (isa<tosa::ReduceAllOp>(Val: op) && elementTy.isInteger(width: 1))
1056 return rewriter.getIntegerAttr(type: elementTy, value: APInt::getAllOnes(numBits: 1));
1057
1058 if (isa<tosa::ReduceAnyOp>(Val: op) && elementTy.isInteger(width: 1))
1059 return rewriter.getIntegerAttr(type: elementTy, value: APInt::getZero(numBits: 1));
1060
1061 if (isa<tosa::ArgMaxOp>(Val: op) && isa<FloatType>(Val: elementTy))
1062 return rewriter.getFloatAttr(
1063 type: elementTy, value: APFloat::getLargest(
1064 Sem: cast<FloatType>(Val&: elementTy).getFloatSemantics(), Negative: true));
1065
1066 if (isa<tosa::ArgMaxOp>(Val: op) && isa<IntegerType>(Val: elementTy))
1067 return rewriter.getIntegerAttr(
1068 type: elementTy, value: APInt::getSignedMinValue(numBits: elementTy.getIntOrFloatBitWidth()));
1069
1070 return {};
1071}
1072
1073// Creates the body calculation for a reduction. The operations vary depending
1074// on the input type.
1075static Value createLinalgBodyCalculationForReduceOp(Operation *op,
1076 ValueRange args,
1077 Type elementTy,
1078 PatternRewriter &rewriter) {
1079 Location loc = op->getLoc();
1080 if (isa<tosa::ReduceSumOp>(Val: op) && isa<FloatType>(Val: elementTy)) {
1081 return rewriter.create<arith::AddFOp>(location: loc, args);
1082 }
1083
1084 if (isa<tosa::ReduceSumOp>(Val: op) && isa<IntegerType>(Val: elementTy)) {
1085 return rewriter.create<arith::AddIOp>(location: loc, args);
1086 }
1087
1088 if (isa<tosa::ReduceProductOp>(Val: op) && isa<FloatType>(Val: elementTy)) {
1089 return rewriter.create<arith::MulFOp>(location: loc, args);
1090 }
1091
1092 if (isa<tosa::ReduceProductOp>(Val: op) && isa<IntegerType>(Val: elementTy)) {
1093 return rewriter.create<arith::MulIOp>(location: loc, args);
1094 }
1095
1096 if (isa<tosa::ReduceMinOp>(Val: op) && isa<FloatType>(Val: elementTy)) {
1097 return rewriter.create<arith::MinimumFOp>(location: loc, args: args[0], args: args[1]);
1098 }
1099
1100 if (isa<tosa::ReduceMinOp>(Val: op) && isa<IntegerType>(Val: elementTy)) {
1101 return rewriter.create<arith::MinSIOp>(location: loc, args: args[0], args: args[1]);
1102 }
1103
1104 if (isa<tosa::ReduceMaxOp>(Val: op) && isa<FloatType>(Val: elementTy)) {
1105 return rewriter.create<arith::MaximumFOp>(location: loc, args: args[0], args: args[1]);
1106 }
1107
1108 if (isa<tosa::ReduceMaxOp>(Val: op) && isa<IntegerType>(Val: elementTy)) {
1109 return rewriter.create<arith::MaxSIOp>(location: loc, args: args[0], args: args[1]);
1110 }
1111
1112 if (isa<tosa::ReduceAllOp>(Val: op) && elementTy.isInteger(width: 1))
1113 return rewriter.create<arith::AndIOp>(location: loc, args);
1114
1115 if (isa<tosa::ReduceAnyOp>(Val: op) && elementTy.isInteger(width: 1))
1116 return rewriter.create<arith::OrIOp>(location: loc, args);
1117
1118 return {};
1119}
1120
1121// Performs the match and rewrite for reduction operations. This includes
1122// declaring a correctly sized initial value, and the linalg.generic operation
1123// that reduces across the specified axis.
1124template <typename OpTy>
1125static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
1126 PatternRewriter &rewriter) {
1127 auto loc = op->getLoc();
1128 auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
1129 auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
1130 if (!inputTy || !resultTy)
1131 return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
1132
1133 auto elementTy = resultTy.getElementType();
1134 Value input = op->getOperand(0);
1135
1136 SmallVector<int64_t> reduceShape;
1137 SmallVector<Value> dynDims;
1138 for (unsigned i = 0; i < inputTy.getRank(); i++) {
1139 if (axis != i) {
1140 reduceShape.push_back(Elt: inputTy.getDimSize(i));
1141 if (inputTy.isDynamicDim(i))
1142 dynDims.push_back(Elt: rewriter.create<tensor::DimOp>(loc, input, i));
1143 }
1144 }
1145
1146 SmallVector<Value> inputs, outputs;
1147 inputs.push_back(Elt: input);
1148
1149 // First fill the output buffer with the init value.
1150 auto emptyTensor =
1151 rewriter
1152 .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
1153 dynDims)
1154 .getResult();
1155
1156 auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
1157 if (!fillValueAttr)
1158 return rewriter.notifyMatchFailure(
1159 op, "No initial value found for reduction operation");
1160
1161 auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
1162 auto filledTensor = rewriter
1163 .create<linalg::FillOp>(loc, ValueRange{fillValue},
1164 ValueRange{emptyTensor})
1165 .result();
1166 outputs.push_back(Elt: filledTensor);
1167
1168 bool isNanIgnoreMode = false;
1169 if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
1170 std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
1171 // NaN propagation has no meaning for non floating point types.
1172 if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
1173 isNanIgnoreMode = true;
1174 // Because the TOSA spec requires the result be NaN iff all elements in
1175 // the reduction are NaN we can't simply perform a compare and select.
1176 // Additionally we have to keep track of whether we've seen any non-NaN
1177 // values and then do a final select based on this predicate.
1178 auto trueAttr = rewriter.getBoolAttr(value: true);
1179 auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
1180 auto emptyBoolTensor =
1181 rewriter
1182 .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
1183 dynDims)
1184 .getResult();
1185 auto allResultsNaNTensor =
1186 rewriter
1187 .create<linalg::FillOp>(loc, ValueRange{trueValue},
1188 ValueRange{emptyBoolTensor})
1189 .result();
1190 // Note that because the linalg::ReduceOp has two variadic arguments
1191 // (inputs and outputs) and it has the SameVariadicOperandSize trait we
1192 // need to have the same number of inputs and outputs.
1193 //
1194 // The second input isn't actually used anywhere since the value used to
1195 // update the NaN flag is calculated inside the body of the reduction and
1196 // then used to update an out value.
1197 // In order to satisfy type constraints we just pass another copy of the
1198 // input here.
1199 inputs.push_back(Elt: input);
1200 outputs.push_back(Elt: allResultsNaNTensor);
1201 }
1202 }
1203
1204 bool didEncounterError = false;
1205 linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
1206 loc, inputs, outputs, axis,
1207 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1208 std::array<Value, 2> binaryArgs{
1209 blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
1210 auto result = createLinalgBodyCalculationForReduceOp(
1211 op, binaryArgs, elementTy, rewriter);
1212 if (result)
1213 didEncounterError = true;
1214
1215 SmallVector<Value> resultsToYield;
1216 if (isNanIgnoreMode) {
1217 auto inputValue = blockArgs[0];
1218 auto initialValue = blockArgs[2];
1219 auto oldAllResultsNanFlagValue = blockArgs[3];
1220
1221 // Unordered comparison of NaN against itself will always return true.
1222 Value isNaN = nestedBuilder.create<arith::CmpFOp>(
1223 op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
1224 // If we've encountered a NaN, take the non-NaN value.
1225 auto selectOp = nestedBuilder.create<arith::SelectOp>(
1226 op->getLoc(), isNaN, initialValue, result);
1227 // Update the flag which keeps track of whether we have seen a non-NaN
1228 // value.
1229 auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
1230 op->getLoc(), oldAllResultsNanFlagValue, isNaN);
1231 resultsToYield.push_back(Elt: selectOp);
1232 resultsToYield.push_back(Elt: newAllResultsNanFlagValue);
1233 } else {
1234 resultsToYield.push_back(Elt: result);
1235 }
1236 nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
1237 });
1238
1239 if (!didEncounterError)
1240 return rewriter.notifyMatchFailure(
1241 op, "unable to create linalg.generic body for reduce op");
1242
1243 if (isNanIgnoreMode) {
1244 // Materialize a check to see whether we encountered any non-NaN values, if
1245 // we didn't we need to select a tensor of NaNs since the result will just
1246 // be the initial identity value propagated through all the compares and
1247 // selects inside the reduction.
1248
1249 // Create a tensor full of NaNs.
1250 auto nanValueAttr = rewriter.getFloatAttr(
1251 elementTy,
1252 APFloat::getNaN(Sem: cast<FloatType>(elementTy).getFloatSemantics(), Negative: false));
1253 auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
1254 auto emptyNanTensor =
1255 rewriter
1256 .create<tensor::EmptyOp>(loc, reduceShape,
1257 resultTy.getElementType(), dynDims)
1258 .getResult();
1259 auto nanFilledTensor =
1260 rewriter
1261 .create<linalg::FillOp>(loc, ValueRange{nanValue},
1262 ValueRange{emptyNanTensor})
1263 .result();
1264
1265 // Create an empty tensor, non need to fill this since it will be
1266 // overwritten by the select.
1267 auto finalEmptyTensor =
1268 rewriter
1269 .create<tensor::EmptyOp>(loc, reduceShape,
1270 resultTy.getElementType(), dynDims)
1271 .getResult();
1272
1273 // Do a selection between the tensors akin to:
1274 // result = NaN if "all results NaN" else result.
1275 SmallVector<Value> ins, outs;
1276 ins.push_back(Elt: linalgOp->getOpResult(idx: 1));
1277 ins.push_back(Elt: nanFilledTensor);
1278 ins.push_back(Elt: linalgOp->getResult(idx: 0));
1279 outs.push_back(Elt: finalEmptyTensor);
1280 auto linalgSelect =
1281 rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
1282 linalgOp = linalgSelect;
1283 }
1284
1285 SmallVector<ReassociationExprs, 4> reassociationMap;
1286 uint64_t expandInputRank =
1287 cast<ShapedType>(Val: linalgOp->getResults()[0].getType()).getRank();
1288 reassociationMap.resize(N: expandInputRank);
1289
1290 for (uint64_t i = 0; i < expandInputRank; i++) {
1291 int32_t dimToPush = i > axis ? i + 1 : i;
1292 reassociationMap[i].push_back(Elt: rewriter.getAffineDimExpr(position: dimToPush));
1293 }
1294
1295 if (expandInputRank != 0) {
1296 int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1297 reassociationMap[expandedDim].push_back(
1298 Elt: rewriter.getAffineDimExpr(position: expandedDim + 1));
1299 }
1300
1301 // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1302 // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1303 // not have access to such information. This matters when handling dynamically
1304 // sized tensors.
1305 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1306 op, resultTy, linalgOp->getResults()[0], reassociationMap);
1307 return success();
1308}
1309
1310namespace {
1311
// Conversion pattern adapter for elementwise TOSA ops: forwards the matched
// op and its adapted operands to elementwiseMatchAndRewriteHelper, which
// performs broadcasting and emits the 'linalg.generic' lowering.
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
1325
1326class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1327public:
1328 using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
1329
1330 LogicalResult matchAndRewrite(tosa::RescaleOp op,
1331 PatternRewriter &rewriter) const final {
1332 auto loc = op.getLoc();
1333 auto input = op.getInput();
1334 auto inputTy = cast<ShapedType>(Val: op.getInput().getType());
1335 auto outputTy = cast<ShapedType>(Val: op.getOutput().getType());
1336 unsigned rank = inputTy.getRank();
1337
1338 // This is an illegal configuration. terminate and log an error
1339 if (op.getRoundingMode() == "INEXACT_ROUND")
1340 return rewriter.notifyMatchFailure(
1341 arg&: op, msg: "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
1342 "currently supported");
1343 if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
1344 return rewriter.notifyMatchFailure(
1345 arg&: op, msg: "tosa.rescale requires scale32 for double_round to be true");
1346
1347 if (!isa<IntegerType>(Val: inputTy.getElementType()))
1348 return rewriter.notifyMatchFailure(arg&: op, msg: "only support integer type");
1349
1350 SmallVector<Value> dynDims;
1351 for (int i = 0; i < outputTy.getRank(); i++) {
1352 if (outputTy.isDynamicDim(idx: i)) {
1353 dynDims.push_back(Elt: rewriter.create<tensor::DimOp>(location: loc, args&: input, args&: i));
1354 }
1355 }
1356
1357 // The shift and multiplier values.
1358 DenseElementsAttr shiftElems;
1359 if (!matchPattern(value: op.getShift(), pattern: m_Constant(bind_value: &shiftElems)))
1360 return rewriter.notifyMatchFailure(
1361 arg&: op, msg: "tosa.rescale requires constant shift input values");
1362
1363 DenseElementsAttr multiplierElems;
1364 if (!matchPattern(value: op.getMultiplier(), pattern: m_Constant(bind_value: &multiplierElems)))
1365 return rewriter.notifyMatchFailure(
1366 arg&: op, msg: "tosa.rescale requires constant multiplier input values");
1367
1368 llvm::SmallVector<int8_t> shiftValues =
1369 llvm::to_vector(Range: shiftElems.getValues<int8_t>());
1370 // explicit cast is required here
1371 llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
1372 Range: llvm::map_range(C: multiplierElems.getValues<IntegerAttr>(),
1373 F: [](IntegerAttr attr) -> int32_t {
1374 return static_cast<int32_t>(attr.getInt());
1375 }));
1376
1377 // If we shift by more than the bitwidth, this just sets to 0.
1378 for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1379 if (shiftValues[i] > 63) {
1380 shiftValues[i] = 0;
1381 multiplierValues[i] = 0;
1382 }
1383 }
1384
1385 // Double round only occurs if shift is greater than 31, check that this
1386 // is ever true.
1387
1388 bool doubleRound =
1389 op.getRoundingMode() == "DOUBLE_ROUND" &&
1390 llvm::any_of(Range&: shiftValues, P: [](int32_t v) { return v > 31; });
1391 StringAttr roundingMode = doubleRound
1392 ? rewriter.getStringAttr(bytes: "DOUBLE_ROUND")
1393 : rewriter.getStringAttr(bytes: "SINGLE_ROUND");
1394
1395 SmallVector<AffineMap> indexingMaps = {
1396 rewriter.getMultiDimIdentityMap(rank)};
1397 SmallVector<Value, 4> genericInputs = {input};
1398
1399 // If we are rescaling per-channel then we need to store the multiplier
1400 // values in a buffer.
1401 Value multiplierConstant;
1402 int64_t multiplierArg = 0;
1403 if (multiplierValues.size() == 1) {
1404 multiplierConstant = rewriter.create<arith::ConstantOp>(
1405 location: loc, args: rewriter.getI32IntegerAttr(value: multiplierValues.front()));
1406 } else {
1407 SmallVector<AffineExpr, 2> multiplierExprs{
1408 rewriter.getAffineDimExpr(position: rank - 1)};
1409 auto multiplierType =
1410 RankedTensorType::get(shape: {static_cast<int64_t>(multiplierValues.size())},
1411 elementType: rewriter.getI32Type());
1412 genericInputs.push_back(Elt: rewriter.create<arith::ConstantOp>(
1413 location: loc, args: DenseIntElementsAttr::get(type: multiplierType, arg&: multiplierValues)));
1414
1415 indexingMaps.push_back(Elt: AffineMap::get(/*dimCount=*/rank,
1416 /*symbolCount=*/0, results: multiplierExprs,
1417 context: rewriter.getContext()));
1418
1419 multiplierArg = indexingMaps.size() - 1;
1420 }
1421
1422 // If we are rescaling per-channel then we need to store the shift
1423 // values in a buffer.
1424 Value shiftConstant;
1425 int64_t shiftArg = 0;
1426 if (shiftValues.size() == 1) {
1427 shiftConstant = rewriter.create<arith::ConstantOp>(
1428 location: loc, args: rewriter.getI8IntegerAttr(value: shiftValues.front()));
1429 } else {
1430 SmallVector<AffineExpr, 2> shiftExprs = {
1431 rewriter.getAffineDimExpr(position: rank - 1)};
1432 auto shiftType =
1433 RankedTensorType::get(shape: {static_cast<int64_t>(shiftValues.size())},
1434 elementType: rewriter.getIntegerType(width: 8));
1435 genericInputs.push_back(Elt: rewriter.create<arith::ConstantOp>(
1436 location: loc, args: DenseIntElementsAttr::get(type: shiftType, arg&: shiftValues)));
1437 indexingMaps.push_back(Elt: AffineMap::get(/*dimCount=*/rank,
1438 /*symbolCount=*/0, results: shiftExprs,
1439 context: rewriter.getContext()));
1440 shiftArg = indexingMaps.size() - 1;
1441 }
1442
1443 // Indexing maps for output values.
1444 indexingMaps.push_back(Elt: rewriter.getMultiDimIdentityMap(rank));
1445
1446 // Construct the indexing maps needed for linalg.generic ops.
1447 Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1448 location: loc, args: outputTy.getShape(), args: outputTy.getElementType(),
1449 args: ArrayRef<Value>({dynDims}));
1450
1451 auto linalgOp = rewriter.create<linalg::GenericOp>(
1452 location: loc, args&: outputTy, args&: genericInputs, args: ValueRange{emptyTensor}, args&: indexingMaps,
1453 args: getNParallelLoopsAttrs(nParallelLoops: rank),
1454 args: [&](OpBuilder &nestedBuilder, Location nestedLoc,
1455 ValueRange blockArgs) {
1456 Value value = blockArgs[0];
1457 Type valueTy = value.getType();
1458
1459 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1460 if (failed(Result: maybeIZp)) {
1461 (void)rewriter.notifyMatchFailure(
1462 arg&: op, msg: "input zero point cannot be statically determined");
1463 return;
1464 }
1465
1466 const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
1467 // Extend zeropoint for sub-32bits widths.
1468 const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
1469 auto inputZp = nestedBuilder.create<arith::ConstantOp>(
1470 location: loc, args: IntegerAttr::get(type: rewriter.getIntegerType(width: inAttrBitwidth),
1471 value: *maybeIZp));
1472
1473 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1474 if (failed(Result: maybeOZp)) {
1475 (void)rewriter.notifyMatchFailure(
1476 arg&: op, msg: "output zero point cannot be statically determined");
1477 return;
1478 };
1479
1480 IntegerType outIntType =
1481 cast<IntegerType>(Val: blockArgs.back().getType());
1482 unsigned outBitWidth = outIntType.getWidth();
1483 const int32_t outAttrBitwidth = 32;
1484 assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
1485 auto outputZp = nestedBuilder.create<arith::ConstantOp>(
1486 location: loc, args: IntegerAttr::get(type: rewriter.getIntegerType(width: outAttrBitwidth),
1487 value: *maybeOZp));
1488
1489 Value multiplier = multiplierConstant ? multiplierConstant
1490 : blockArgs[multiplierArg];
1491 Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1492
1493 if (valueTy.isUnsignedInteger()) {
1494 value = nestedBuilder
1495 .create<UnrealizedConversionCastOp>(
1496 location: nestedLoc,
1497 args: nestedBuilder.getIntegerType(
1498 width: valueTy.getIntOrFloatBitWidth()),
1499 args&: value)
1500 .getResult(i: 0);
1501 }
1502 if (valueTy.getIntOrFloatBitWidth() < 32) {
1503 if (op.getInputUnsigned()) {
1504 value = nestedBuilder.create<arith::ExtUIOp>(
1505 location: nestedLoc, args: nestedBuilder.getI32Type(), args&: value);
1506 } else {
1507 value = nestedBuilder.create<arith::ExtSIOp>(
1508 location: nestedLoc, args: nestedBuilder.getI32Type(), args&: value);
1509 }
1510 }
1511
1512 value =
1513 nestedBuilder.create<arith::SubIOp>(location: nestedLoc, args&: value, args&: inputZp);
1514
1515 value = nestedBuilder.create<tosa::ApplyScaleOp>(
1516 location: loc, args: nestedBuilder.getI32Type(), args&: value, args&: multiplier, args&: shift,
1517 args&: roundingMode);
1518
1519 // Move to the new zero-point.
1520 value =
1521 nestedBuilder.create<arith::AddIOp>(location: nestedLoc, args&: value, args&: outputZp);
1522
1523 // Saturate to the output size.
1524 int32_t intMin = APInt::getSignedMinValue(numBits: outBitWidth).getSExtValue();
1525 int32_t intMax = APInt::getSignedMaxValue(numBits: outBitWidth).getSExtValue();
1526
1527 // Unsigned integers have a difference output value.
1528 if (op.getOutputUnsigned()) {
1529 intMin = 0;
1530 intMax = APInt::getMaxValue(numBits: outBitWidth).getZExtValue();
1531 }
1532
1533 auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
1534 location: loc, args: nestedBuilder.getI32IntegerAttr(value: intMin));
1535 auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
1536 location: loc, args: nestedBuilder.getI32IntegerAttr(value: intMax));
1537
1538 value = clampIntHelper(loc: nestedLoc, arg: value, min: intMinVal, max: intMaxVal,
1539 rewriter&: nestedBuilder, /*isUnsigned=*/false);
1540
1541 if (outIntType.getWidth() < 32) {
1542 value = nestedBuilder.create<arith::TruncIOp>(
1543 location: nestedLoc, args: rewriter.getIntegerType(width: outIntType.getWidth()),
1544 args&: value);
1545 }
1546
1547 if (outIntType.isUnsignedInteger()) {
1548 value = nestedBuilder
1549 .create<UnrealizedConversionCastOp>(location: nestedLoc,
1550 args&: outIntType, args&: value)
1551 .getResult(i: 0);
1552 }
1553 nestedBuilder.create<linalg::YieldOp>(location: loc, args&: value);
1554 });
1555
1556 rewriter.replaceOp(op, newValues: linalgOp->getResults());
1557 return success();
1558 }
1559};
1560
// Handle the resize case where the input is a 1x1 image. This case
// can entirely avoid having extract operations, which are much more
// difficult to optimize away.
1564class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1565public:
1566 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1567
1568 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1569 PatternRewriter &rewriter) const final {
1570 Location loc = op.getLoc();
1571 ImplicitLocOpBuilder builder(loc, rewriter);
1572 auto input = op.getInput();
1573 auto inputTy = cast<RankedTensorType>(Val: input.getType());
1574 auto resultTy = cast<RankedTensorType>(Val: op.getType());
1575 const bool isBilinear = op.getMode() == "BILINEAR";
1576
1577 auto inputH = inputTy.getDimSize(idx: 1);
1578 auto inputW = inputTy.getDimSize(idx: 2);
1579 auto outputH = resultTy.getDimSize(idx: 1);
1580 auto outputW = resultTy.getDimSize(idx: 2);
1581
1582 if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1583 return rewriter.notifyMatchFailure(
1584 arg&: op, msg: "tosa.resize is not a pure 1x1->1x1 image operation");
1585
1586 // TODO(suderman): These string values should be declared the TOSA dialect.
1587 if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1588 return rewriter.notifyMatchFailure(
1589 arg&: op, msg: "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1590
1591 if (inputTy == resultTy) {
1592 rewriter.replaceOp(op, newValues: input);
1593 return success();
1594 }
1595
1596 SmallVector<int64_t> scale;
1597 if (!tosa::getConstShapeValues(op: op.getScale().getDefiningOp(), result_shape&: scale)) {
1598 return failure();
1599 }
1600
1601 // Collapse the unit width and height away.
1602 SmallVector<ReassociationExprs, 4> reassociationMap(2);
1603 reassociationMap[0].push_back(Elt: builder.getAffineDimExpr(position: 0));
1604 reassociationMap[1].push_back(Elt: builder.getAffineDimExpr(position: 1));
1605 reassociationMap[1].push_back(Elt: builder.getAffineDimExpr(position: 2));
1606 reassociationMap[1].push_back(Elt: builder.getAffineDimExpr(position: 3));
1607
1608 auto collapseTy =
1609 RankedTensorType::get(shape: {inputTy.getDimSize(idx: 0), inputTy.getDimSize(idx: 3)},
1610 elementType: inputTy.getElementType());
1611 Value collapse = builder.create<tensor::CollapseShapeOp>(args&: collapseTy, args&: input,
1612 args&: reassociationMap);
1613
1614 // Get any dynamic shapes that appear in the input format.
1615 llvm::SmallVector<Value> outputDynSize;
1616 if (inputTy.isDynamicDim(idx: 0))
1617 outputDynSize.push_back(Elt: builder.create<tensor::DimOp>(args&: input, args: 0));
1618 if (inputTy.isDynamicDim(idx: 3))
1619 outputDynSize.push_back(Elt: builder.create<tensor::DimOp>(args&: input, args: 3));
1620
1621 // Generate the elementwise operation for casting scaling the input value.
1622 auto genericTy = collapseTy.clone(elementType: resultTy.getElementType());
1623 Value empty = builder.create<tensor::EmptyOp>(
1624 args: genericTy.getShape(), args: resultTy.getElementType(), args&: outputDynSize);
1625 auto genericMap = rewriter.getMultiDimIdentityMap(rank: genericTy.getRank());
1626 SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1627 utils::IteratorType::parallel);
1628
1629 auto generic = builder.create<linalg::GenericOp>(
1630 args&: genericTy, args: ValueRange{collapse}, args: ValueRange{empty},
1631 args: ArrayRef<AffineMap>{genericMap, genericMap}, args&: iterators,
1632 args: [=](OpBuilder &b, Location loc, ValueRange args) {
1633 Value value = args[0];
1634 // This is the quantized case.
1635 if (inputTy.getElementType() != resultTy.getElementType()) {
1636 value =
1637 b.create<arith::ExtSIOp>(location: loc, args: resultTy.getElementType(), args&: value);
1638
1639 if (isBilinear && scale[0] != 0) {
1640 Value scaleY = b.create<arith::ConstantOp>(
1641 location: loc, args: b.getI32IntegerAttr(value: scale[0]));
1642 value = b.create<arith::MulIOp>(location: loc, args&: value, args&: scaleY);
1643 }
1644
1645 if (isBilinear && scale[2] != 0) {
1646 Value scaleX = b.create<arith::ConstantOp>(
1647 location: loc, args: b.getI32IntegerAttr(value: scale[2]));
1648 value = b.create<arith::MulIOp>(location: loc, args&: value, args&: scaleX);
1649 }
1650 }
1651
1652 b.create<linalg::YieldOp>(location: loc, args&: value);
1653 });
1654
1655 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1656 op, args&: resultTy, args: generic.getResults()[0], args&: reassociationMap);
1657 return success();
1658 }
1659};
1660
1661// TOSA resize with width or height of 1 may be broadcasted to a wider
1662// dimension. This is done by materializing a new tosa.resize without
1663// the broadcasting behavior, and an explicit broadcast afterwards.
1664class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
1665public:
1666 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1667
1668 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1669 PatternRewriter &rewriter) const final {
1670 Location loc = op.getLoc();
1671 ImplicitLocOpBuilder builder(loc, rewriter);
1672 auto input = op.getInput();
1673 auto inputTy = dyn_cast<RankedTensorType>(Val: input.getType());
1674 auto resultTy = dyn_cast<RankedTensorType>(Val: op.getType());
1675
1676 if (!inputTy || !resultTy)
1677 return rewriter.notifyMatchFailure(arg&: op,
1678 msg: "requires ranked input/output types");
1679
1680 auto batch = inputTy.getDimSize(idx: 0);
1681 auto channels = inputTy.getDimSize(idx: 3);
1682 auto inputH = inputTy.getDimSize(idx: 1);
1683 auto inputW = inputTy.getDimSize(idx: 2);
1684 auto outputH = resultTy.getDimSize(idx: 1);
1685 auto outputW = resultTy.getDimSize(idx: 2);
1686
1687 if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1688 return rewriter.notifyMatchFailure(
1689 arg&: op, msg: "tosa.resize has no broadcasting behavior");
1690
1691 // For any dimension that is broadcastable we generate a width of 1
1692 // on the output.
1693 llvm::SmallVector<int64_t> resizeShape;
1694 resizeShape.push_back(Elt: batch);
1695 resizeShape.push_back(Elt: inputH == 1 ? 1 : outputH);
1696 resizeShape.push_back(Elt: inputW == 1 ? 1 : outputW);
1697 resizeShape.push_back(Elt: channels);
1698
1699 auto resizeTy = resultTy.clone(shape: resizeShape);
1700 auto resize = builder.create<tosa::ResizeOp>(args&: resizeTy, args&: input, args: op.getScale(),
1701 args: op.getOffset(), args: op.getBorder(),
1702 args: op.getMode());
1703
1704 // Collapse an unit result dims.
1705 SmallVector<ReassociationExprs, 4> reassociationMap(2);
1706 reassociationMap[0].push_back(Elt: builder.getAffineDimExpr(position: 0));
1707 reassociationMap.back().push_back(Elt: builder.getAffineDimExpr(position: 1));
1708 if (inputH != 1)
1709 reassociationMap.push_back(Elt: {});
1710 reassociationMap.back().push_back(Elt: builder.getAffineDimExpr(position: 2));
1711 if (inputW != 1)
1712 reassociationMap.push_back(Elt: {});
1713 reassociationMap.back().push_back(Elt: builder.getAffineDimExpr(position: 3));
1714
1715 llvm::SmallVector<int64_t> collapseShape = {batch};
1716 if (inputH != 1)
1717 collapseShape.push_back(Elt: outputH);
1718 if (inputW != 1)
1719 collapseShape.push_back(Elt: outputW);
1720 collapseShape.push_back(Elt: channels);
1721
1722 auto collapseTy = resultTy.clone(shape: collapseShape);
1723 Value collapse = builder.create<tensor::CollapseShapeOp>(args&: collapseTy, args&: resize,
1724 args&: reassociationMap);
1725
1726 // Broadcast the collapsed shape to the output result.
1727 llvm::SmallVector<Value> outputDynSize;
1728 if (inputTy.isDynamicDim(idx: 0))
1729 outputDynSize.push_back(Elt: builder.create<tensor::DimOp>(args&: input, args: 0));
1730 if (inputTy.isDynamicDim(idx: 3))
1731 outputDynSize.push_back(Elt: builder.create<tensor::DimOp>(args&: input, args: 3));
1732
1733 SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
1734 utils::IteratorType::parallel);
1735 Value empty = builder.create<tensor::EmptyOp>(
1736 args: resultTy.getShape(), args: resultTy.getElementType(), args&: outputDynSize);
1737
1738 SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(position: 0)};
1739 if (inputH != 1)
1740 inputExprs.push_back(Elt: rewriter.getAffineDimExpr(position: 1));
1741 if (inputW != 1)
1742 inputExprs.push_back(Elt: rewriter.getAffineDimExpr(position: 2));
1743 inputExprs.push_back(Elt: rewriter.getAffineDimExpr(position: 3));
1744
1745 auto inputMap = AffineMap::get(dimCount: resultTy.getRank(), /*symbolCount=*/0,
1746 results: inputExprs, context: rewriter.getContext());
1747
1748 auto outputMap = rewriter.getMultiDimIdentityMap(rank: resultTy.getRank());
1749 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1750 op, args&: resultTy, args: ValueRange{collapse}, args: ValueRange{empty},
1751 args: ArrayRef<AffineMap>{inputMap, outputMap}, args&: iterators,
1752 args: [=](OpBuilder &b, Location loc, ValueRange args) {
1753 Value value = args[0];
1754 b.create<linalg::YieldOp>(location: loc, args&: value);
1755 });
1756
1757 return success();
1758 }
1759};
1760
1761class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1762public:
1763 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1764
1765 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1766 PatternRewriter &rewriter) const final {
1767 Location loc = op.getLoc();
1768 ImplicitLocOpBuilder b(loc, rewriter);
1769 auto input = op.getInput();
1770 auto inputTy = cast<ShapedType>(Val: input.getType());
1771 auto resultTy = cast<ShapedType>(Val: op.getType());
1772 auto resultETy = resultTy.getElementType();
1773
1774 bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
1775 auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
1776
1777 auto imageH = inputTy.getShape()[1];
1778 auto imageW = inputTy.getShape()[2];
1779
1780 auto dynamicDimsOr =
1781 checkHasDynamicBatchDims(rewriter, op, params: {input, op.getOutput()});
1782 if (!dynamicDimsOr.has_value())
1783 return rewriter.notifyMatchFailure(
1784 arg&: op, msg: "unable to get dynamic dimensions of tosa.resize");
1785
1786 if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1787 return rewriter.notifyMatchFailure(
1788 arg&: op, msg: "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1789
1790 SmallVector<AffineMap, 2> affineMaps = {
1791 rewriter.getMultiDimIdentityMap(rank: resultTy.getRank())};
1792 auto emptyTensor = b.create<tensor::EmptyOp>(args: resultTy.getShape(), args&: resultETy,
1793 args&: *dynamicDimsOr);
1794 auto genericOp = b.create<linalg::GenericOp>(
1795 args&: resultTy, args: ValueRange({}), args: ValueRange{emptyTensor}, args&: affineMaps,
1796 args: getNParallelLoopsAttrs(nParallelLoops: resultTy.getRank()));
1797 Value resize = genericOp.getResult(i: 0);
1798
1799 {
1800 OpBuilder::InsertionGuard regionGuard(b);
1801 b.createBlock(parent: &genericOp.getRegion(), insertPt: genericOp.getRegion().end(),
1802 argTypes: TypeRange({resultETy}), locs: loc);
1803 Value batch = b.create<linalg::IndexOp>(args: 0);
1804 Value y = b.create<linalg::IndexOp>(args: 1);
1805 Value x = b.create<linalg::IndexOp>(args: 2);
1806 Value channel = b.create<linalg::IndexOp>(args: 3);
1807
1808 Value zeroI32 =
1809 b.create<arith::ConstantOp>(args: b.getZeroAttr(type: b.getI32Type()));
1810 Value zeroFp = b.create<arith::ConstantOp>(args: b.getZeroAttr(type: floatTy));
1811 Value hMax = b.create<arith::ConstantOp>(args: b.getI32IntegerAttr(value: imageH - 1));
1812 Value wMax = b.create<arith::ConstantOp>(args: b.getI32IntegerAttr(value: imageW - 1));
1813
1814 Value inY = b.create<arith::IndexCastOp>(args: b.getI32Type(), args&: y);
1815 Value inX = b.create<arith::IndexCastOp>(args: b.getI32Type(), args&: x);
1816
1817 SmallVector<int64_t> scale, offset, border;
1818 if (!tosa::getConstShapeValues(op: op.getScale().getDefiningOp(), result_shape&: scale) ||
1819 !tosa::getConstShapeValues(op: op.getOffset().getDefiningOp(), result_shape&: offset) ||
1820 !tosa::getConstShapeValues(op: op.getBorder().getDefiningOp(), result_shape&: border)) {
1821 return rewriter.notifyMatchFailure(
1822 arg&: op, msg: "tosa.resize scale/offset/border should have compile time "
1823 "constant values.");
1824 }
1825
1826 Value yScaleN, yScaleD, xScaleN, xScaleD;
1827 yScaleN = b.create<arith::ConstantOp>(args: b.getI32IntegerAttr(value: scale[0]));
1828 yScaleD = b.create<arith::ConstantOp>(args: b.getI32IntegerAttr(value: scale[1]));
1829 xScaleN = b.create<arith::ConstantOp>(args: b.getI32IntegerAttr(value: scale[2]));
1830 xScaleD = b.create<arith::ConstantOp>(args: b.getI32IntegerAttr(value: scale[3]));
1831
1832 Value yOffset, xOffset, yBorder, xBorder;
1833 yOffset = b.create<arith::ConstantOp>(args: b.getI32IntegerAttr(value: offset[0]));
1834 xOffset = b.create<arith::ConstantOp>(args: b.getI32IntegerAttr(value: offset[1]));
1835 yBorder = b.create<arith::ConstantOp>(args: b.getI32IntegerAttr(value: border[0]));
1836 xBorder = b.create<arith::ConstantOp>(args: b.getI32IntegerAttr(value: border[1]));
1837
1838 // Compute the ix and dx values for both the X and Y dimensions.
1839 auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
1840 Value scaleN, Value scaleD, Value offset,
1841 int size, ImplicitLocOpBuilder &b) {
1842 if (size == 1) {
1843 index = zeroI32;
1844 delta = zeroFp;
1845 return;
1846 }
1847 // x = x * scale_d + offset;
1848 // ix = floor(x / scale_n)
1849 Value val = b.create<arith::MulIOp>(args&: in, args&: scaleD);
1850 val = b.create<arith::AddIOp>(args&: val, args&: offset);
1851 index = b.create<arith::FloorDivSIOp>(args&: val, args&: scaleN);
1852
1853 // rx = x % scale_n
1854 // dx = rx / scale_n
1855 Value r = b.create<arith::RemSIOp>(args&: val, args&: scaleN);
1856 Value rFp = b.create<arith::SIToFPOp>(args&: floatTy, args&: r);
1857 Value scaleNfp = b.create<arith::UIToFPOp>(args&: floatTy, args&: scaleN);
1858 delta = b.create<arith::DivFOp>(args&: rFp, args&: scaleNfp);
1859 };
1860
1861 // Compute the ix and dx values for the X and Y dimensions - int case.
1862 auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
1863 Value scaleN, Value scaleD, Value offset,
1864 int size, ImplicitLocOpBuilder &b) {
1865 if (size == 1) {
1866 index = zeroI32;
1867 delta = zeroI32;
1868 return;
1869 }
1870 // x = x * scale_d + offset;
1871 // ix = floor(x / scale_n)
1872 // dx = x - ix * scale_n;
1873 Value val = b.create<arith::MulIOp>(args&: in, args&: scaleD);
1874 val = b.create<arith::AddIOp>(args&: val, args&: offset);
1875 index = b.create<arith::DivSIOp>(args&: val, args&: scaleN);
1876 delta = b.create<arith::MulIOp>(args&: index, args&: scaleN);
1877 delta = b.create<arith::SubIOp>(args&: val, args&: delta);
1878 };
1879
1880 Value ix, iy, dx, dy;
1881 if (floatingPointMode) {
1882 getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1883 getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1884 } else {
1885 getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1886 getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1887 }
1888
1889 if (op.getMode() == "NEAREST_NEIGHBOR") {
1890 auto one = b.create<arith::ConstantOp>(args: b.getI32IntegerAttr(value: 1));
1891
1892 auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
1893 Value max, int size,
1894 ImplicitLocOpBuilder &b) -> Value {
1895 if (size == 1) {
1896 return b.create<arith::ConstantIndexOp>(args: 0);
1897 }
1898
1899 Value pred;
1900 if (floatingPointMode) {
1901 auto h = b.create<arith::ConstantOp>(args: b.getFloatAttr(type: floatTy, value: 0.5f));
1902 pred = b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGE, args&: dval, args&: h);
1903 } else {
1904 Value dvalDouble = b.create<arith::ShLIOp>(args&: dval, args&: one);
1905 pred = b.create<arith::CmpIOp>(args: arith::CmpIPredicate::sge,
1906 args&: dvalDouble, args&: scale);
1907 }
1908
1909 auto offset = b.create<arith::SelectOp>(args&: pred, args&: one, args&: zeroI32);
1910 val = b.create<arith::AddIOp>(args&: val, args&: offset);
1911 val = clampIntHelper(loc, arg: val, min: zeroI32, max, rewriter&: b, /*isUnsigned=*/false);
1912 return b.create<arith::IndexCastOp>(args: b.getIndexType(), args&: val);
1913 };
1914
1915 iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
1916 ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
1917
1918 Value result = b.create<tensor::ExtractOp>(
1919 args&: input, args: ValueRange{batch, iy, ix, channel});
1920
1921 b.create<linalg::YieldOp>(args&: result);
1922 } else {
1923 // The mode here must be BILINEAR.
1924 assert(op.getMode() == "BILINEAR");
1925
1926 auto oneVal = b.create<arith::ConstantOp>(args: b.getI32IntegerAttr(value: 1));
1927
1928 auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
1929 Value max, ImplicitLocOpBuilder &b) {
1930 val0 = in;
1931 val1 = b.create<arith::AddIOp>(args&: val0, args&: oneVal);
1932 val0 =
1933 clampIntHelper(loc, arg: val0, min: zeroI32, max, rewriter&: b, /*isUnsigned=*/false);
1934 val1 =
1935 clampIntHelper(loc, arg: val1, min: zeroI32, max, rewriter&: b, /*isUnsigned=*/false);
1936 val0 = b.create<arith::IndexCastOp>(args: b.getIndexType(), args&: val0);
1937 val1 = b.create<arith::IndexCastOp>(args: b.getIndexType(), args&: val1);
1938 };
1939
1940 // Linalg equivalent to the section below:
1941 // int16_t iy0 = apply_max(iy, 0);
1942 // int16_t iy1 = apply_min(iy + 1, IH - 1);
1943 // int16_t ix0 = apply_max(ix, 0);
1944 // int16_t ix1 = apply_min(ix + 1, IW - 1);
1945 Value x0, x1, y0, y1;
1946 getClampedIdxs(y0, y1, imageH, iy, hMax, b);
1947 getClampedIdxs(x0, x1, imageW, ix, wMax, b);
1948
1949 Value y0x0 = b.create<tensor::ExtractOp>(
1950 args&: input, args: ValueRange{batch, y0, x0, channel});
1951 Value y0x1 = b.create<tensor::ExtractOp>(
1952 args&: input, args: ValueRange{batch, y0, x1, channel});
1953 Value y1x0 = b.create<tensor::ExtractOp>(
1954 args&: input, args: ValueRange{batch, y1, x0, channel});
1955 Value y1x1 = b.create<tensor::ExtractOp>(
1956 args&: input, args: ValueRange{batch, y1, x1, channel});
1957
1958 if (floatingPointMode) {
1959 auto oneVal =
1960 b.create<arith::ConstantOp>(args: b.getFloatAttr(type: floatTy, value: 1.0f));
1961 auto interpolate = [&](Value val0, Value val1, Value delta,
1962 int inputSize,
1963 ImplicitLocOpBuilder &b) -> Value {
1964 if (inputSize == 1)
1965 return val0;
1966 Value oneMinusDelta = b.create<arith::SubFOp>(args&: oneVal, args&: delta);
1967 Value mul0 = b.create<arith::MulFOp>(args&: val0, args&: oneMinusDelta);
1968 Value mul1 = b.create<arith::MulFOp>(args&: val1, args&: delta);
1969 return b.create<arith::AddFOp>(args&: mul0, args&: mul1);
1970 };
1971
1972 // Linalg equivalent to the section below:
1973 // topAcc = v00 * (unit_x - dx);
1974 // topAcc += v01 * dx;
1975 Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
1976
1977 // Linalg equivalent to the section below:
1978 // bottomAcc = v10 * (unit_x - dx);
1979 // bottomAcc += v11 * dx;
1980 Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
1981
1982 // Linalg equivalent to the section below:
1983 // result = topAcc * (unit_y - dy) + bottomAcc * dy
1984 Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
1985 b.create<linalg::YieldOp>(args&: result);
1986 } else {
1987 // Perform in quantized space.
1988 y0x0 = b.create<arith::ExtSIOp>(args&: resultETy, args&: y0x0);
1989 y0x1 = b.create<arith::ExtSIOp>(args&: resultETy, args&: y0x1);
1990 y1x0 = b.create<arith::ExtSIOp>(args&: resultETy, args&: y1x0);
1991 y1x1 = b.create<arith::ExtSIOp>(args&: resultETy, args&: y1x1);
1992
1993 const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
1994 if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
1995 dx = b.create<arith::ExtSIOp>(args&: resultETy, args&: dx);
1996 dy = b.create<arith::ExtSIOp>(args&: resultETy, args&: dy);
1997 }
1998
1999 Value yScaleNExt = yScaleN;
2000 Value xScaleNExt = xScaleN;
2001
2002 const int64_t scaleBitwidth =
2003 xScaleN.getType().getIntOrFloatBitWidth();
2004 if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
2005 yScaleNExt = b.create<arith::ExtSIOp>(args&: resultETy, args&: yScaleN);
2006 xScaleNExt = b.create<arith::ExtSIOp>(args&: resultETy, args&: xScaleN);
2007 }
2008
2009 auto interpolate = [](Value val0, Value val1, Value weight1,
2010 Value scale, int inputSize,
2011 ImplicitLocOpBuilder &b) -> Value {
2012 if (inputSize == 1)
2013 return b.create<arith::MulIOp>(args&: val0, args&: scale);
2014 Value weight0 = b.create<arith::SubIOp>(args&: scale, args&: weight1);
2015 Value mul0 = b.create<arith::MulIOp>(args&: val0, args&: weight0);
2016 Value mul1 = b.create<arith::MulIOp>(args&: val1, args&: weight1);
2017 return b.create<arith::AddIOp>(args&: mul0, args&: mul1);
2018 };
2019
2020 Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
2021 Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
2022 Value result =
2023 interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
2024 b.create<linalg::YieldOp>(args&: result);
2025 }
2026 }
2027 }
2028
2029 rewriter.replaceOp(op, newValues: resize);
2030 return success();
2031 }
2032};
2033
2034// At the codegen level any identity operations should be removed. Any cases
2035// where identity is load-bearing (e.g. cross device computation) should be
2036// handled before lowering to codegen.
2037template <typename SrcOp>
2038class IdentityNConverter : public OpRewritePattern<SrcOp> {
2039public:
2040 using OpRewritePattern<SrcOp>::OpRewritePattern;
2041
2042 LogicalResult matchAndRewrite(SrcOp op,
2043 PatternRewriter &rewriter) const final {
2044 rewriter.replaceOp(op, op.getOperation()->getOperands());
2045 return success();
2046 }
2047};
2048
2049template <typename SrcOp>
2050class ReduceConverter : public OpRewritePattern<SrcOp> {
2051public:
2052 using OpRewritePattern<SrcOp>::OpRewritePattern;
2053
2054 LogicalResult matchAndRewrite(SrcOp reduceOp,
2055 PatternRewriter &rewriter) const final {
2056 return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
2057 }
2058};
2059
2060class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
2061public:
2062 using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
2063
2064 LogicalResult matchAndRewrite(tosa::ReverseOp op,
2065 PatternRewriter &rewriter) const final {
2066 auto loc = op.getLoc();
2067 Value input = op.getInput1();
2068 auto inputTy = cast<ShapedType>(Val: input.getType());
2069 auto resultTy = cast<ShapedType>(Val: op.getType());
2070 auto axis = op.getAxis();
2071
2072 SmallVector<Value> dynDims;
2073 for (int i = 0; i < inputTy.getRank(); i++) {
2074 if (inputTy.isDynamicDim(idx: i)) {
2075 dynDims.push_back(Elt: rewriter.create<tensor::DimOp>(location: loc, args&: input, args&: i));
2076 }
2077 }
2078
2079 Value axisDimSize = rewriter.create<tensor::DimOp>(location: loc, args&: input, args&: axis);
2080
2081 // First fill the output buffer with the init value.
2082 auto emptyTensor = rewriter
2083 .create<tensor::EmptyOp>(location: loc, args: inputTy.getShape(),
2084 args: inputTy.getElementType(),
2085 args: ArrayRef<Value>({dynDims}))
2086 .getResult();
2087 SmallVector<AffineMap, 2> affineMaps = {
2088 rewriter.getMultiDimIdentityMap(rank: resultTy.getRank())};
2089
2090 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2091 op, args&: resultTy, args: ArrayRef<Value>({}), args: ValueRange{emptyTensor}, args&: affineMaps,
2092 args: getNParallelLoopsAttrs(nParallelLoops: resultTy.getRank()),
2093 args: [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
2094 llvm::SmallVector<Value> indices;
2095 for (unsigned int i = 0; i < inputTy.getRank(); i++) {
2096 Value index =
2097 rewriter.create<linalg::IndexOp>(location: nestedLoc, args&: i).getResult();
2098 if (i == axis) {
2099 auto one = rewriter.create<arith::ConstantIndexOp>(location: nestedLoc, args: 1);
2100 auto sizeMinusOne =
2101 rewriter.create<arith::SubIOp>(location: nestedLoc, args&: axisDimSize, args&: one);
2102 index = rewriter.create<arith::SubIOp>(location: nestedLoc, args&: sizeMinusOne,
2103 args&: index);
2104 }
2105
2106 indices.push_back(Elt: index);
2107 }
2108
2109 auto extract = nestedBuilder.create<tensor::ExtractOp>(
2110 location: nestedLoc, args&: input, args&: indices);
2111 nestedBuilder.create<linalg::YieldOp>(location: op.getLoc(),
2112 args: extract.getResult());
2113 });
2114 return success();
2115 }
2116};
2117
// This converter translates a tile operation to a reshape, broadcast, reshape.
// The first reshape minimally expands each tiled dimension to include a
// preceding size-1 dim. This dim is then broadcasted to the appropriate
// multiple.
2122struct TileConverter : public OpConversionPattern<tosa::TileOp> {
2123 using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
2124
2125 LogicalResult
2126 matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
2127 ConversionPatternRewriter &rewriter) const override {
2128 auto loc = op.getLoc();
2129 auto input = op.getInput1();
2130 auto inputTy = cast<ShapedType>(Val: input.getType());
2131 auto inputShape = inputTy.getShape();
2132 auto resultTy = cast<ShapedType>(Val: op.getType());
2133 auto elementTy = inputTy.getElementType();
2134 int64_t rank = inputTy.getRank();
2135
2136 SmallVector<int64_t> multiples;
2137 if (failed(Result: op.getConstantMultiples(multiples)))
2138 return failure();
2139
2140 // Broadcast the newly added dimensions to their appropriate multiple.
2141 SmallVector<int64_t, 2> genericShape;
2142 for (int i = 0; i < rank; i++) {
2143 int64_t dim = multiples[i];
2144 genericShape.push_back(Elt: dim == -1 ? ShapedType::kDynamic : dim);
2145 genericShape.push_back(Elt: inputShape[i]);
2146 }
2147
2148 SmallVector<Value> dynDims;
2149 for (int i = 0; i < inputTy.getRank(); i++) {
2150 if (inputTy.isDynamicDim(idx: i) || multiples[i] == -1) {
2151 dynDims.push_back(Elt: rewriter.create<tensor::DimOp>(location: loc, args&: input, args&: i));
2152 }
2153 }
2154
2155 auto emptyTensor = rewriter.create<tensor::EmptyOp>(
2156 location: op.getLoc(), args&: genericShape, args&: elementTy, args&: dynDims);
2157
2158 // We needs to map the input shape to the non-broadcasted dimensions.
2159 SmallVector<AffineExpr, 4> dimExprs;
2160 dimExprs.reserve(N: rank);
2161 for (unsigned i = 0; i < rank; ++i)
2162 dimExprs.push_back(Elt: rewriter.getAffineDimExpr(position: i * 2 + 1));
2163
2164 auto readAffineMap =
2165 AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, results: dimExprs,
2166 context: rewriter.getContext());
2167
2168 SmallVector<AffineMap, 2> affineMaps = {
2169 readAffineMap, rewriter.getMultiDimIdentityMap(rank: genericShape.size())};
2170
2171 auto genericOp = rewriter.create<linalg::GenericOp>(
2172 location: loc, args: RankedTensorType::get(shape: genericShape, elementType: elementTy), args&: input,
2173 args: ValueRange{emptyTensor}, args&: affineMaps,
2174 args: getNParallelLoopsAttrs(nParallelLoops: genericShape.size()),
2175 args: [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
2176 nestedBuilder.create<linalg::YieldOp>(location: op.getLoc(), args: *args.begin());
2177 });
2178
2179 auto shapeValue = getTosaConstShape(
2180 rewriter, loc, shape: mlir::tosa::convertFromMlirShape(shape: resultTy.getShape()));
2181 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
2182 op, args&: resultTy, args: genericOp.getResult(i: 0), args&: shapeValue);
2183 return success();
2184 }
2185};
2186
// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is of the resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is discarded as only the index is
// requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    // The running-max buffer has the result shape but the input element type.
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // Collect runtime sizes for the dynamic result dimensions (every dynamic
    // input dimension except the reduction axis).
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       outElementTy, dynDims)
                              .getResult();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       inElementTy, dynDims)
                              .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{emptyTensorMax})
            .result();

    // We need to reduce along the arg-max axis, with parallel operations along
    // the rest.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    // Identity map for the input; the two outputs drop the reduction dim.
    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          // Position along the reduction axis, cast to the result int type.
          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            if (argmaxOp.getNanMode() == "IGNORE") {
              // Only update index & max value for non NaN values. If all
              // values are NaNs, the initial index will be returned which is 0.
              predicate = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
            } else {
              // Update max value if either of the following is true:
              // - new value is bigger
              // - cur max is not NaN and new value is NaN
              Value gt = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
              Value oldNonNaN = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
              predicate = rewriter.create<arith::AndIOp>(
                  nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
            }
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            // Unsupported element type; the failure is reported below.
            didEncounterError = true;
            return;
          }

          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    // Only the index result is requested; the max-value tensor is dropped.
    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
2332
2333class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2334public:
2335 using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2336 LogicalResult
2337 matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2338 ConversionPatternRewriter &rewriter) const final {
2339 auto input = adaptor.getOperands()[0];
2340 auto indices = adaptor.getOperands()[1];
2341
2342 auto valuesTy = dyn_cast<RankedTensorType>(Val: op.getValues().getType());
2343 auto resultTy = dyn_cast<RankedTensorType>(Val: op.getType());
2344 if (!valuesTy || !resultTy)
2345 return rewriter.notifyMatchFailure(arg&: op, msg: "unranked tensors not supported");
2346
2347 auto dynamicDims = inferDynamicDimsForGather(
2348 builder&: rewriter, loc: op.getLoc(), values: adaptor.getValues(), indices: adaptor.getIndices());
2349
2350 auto resultElementTy = resultTy.getElementType();
2351
2352 auto loc = op.getLoc();
2353 auto emptyTensor =
2354 rewriter
2355 .create<tensor::EmptyOp>(location: loc, args: resultTy.getShape(), args&: resultElementTy,
2356 args&: dynamicDims)
2357 .getResult();
2358
2359 SmallVector<AffineMap, 2> affineMaps = {
2360 AffineMap::get(
2361 /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2362 results: {rewriter.getAffineDimExpr(position: 0), rewriter.getAffineDimExpr(position: 1)},
2363 context: rewriter.getContext()),
2364 rewriter.getMultiDimIdentityMap(rank: resultTy.getRank())};
2365
2366 auto genericOp = rewriter.create<linalg::GenericOp>(
2367 location: loc, args: ArrayRef<Type>({resultTy}), args: ValueRange{indices},
2368 args: ValueRange{emptyTensor}, args&: affineMaps,
2369 args: getNParallelLoopsAttrs(nParallelLoops: resultTy.getRank()),
2370 args: [&](OpBuilder &b, Location loc, ValueRange args) {
2371 auto indexValue = args[0];
2372 auto index0 = rewriter.create<linalg::IndexOp>(location: loc, args: 0);
2373 Value index1 = rewriter.create<arith::IndexCastOp>(
2374 location: loc, args: rewriter.getIndexType(), args&: indexValue);
2375 auto index2 = rewriter.create<linalg::IndexOp>(location: loc, args: 2);
2376 Value extract = rewriter.create<tensor::ExtractOp>(
2377 location: loc, args&: input, args: ValueRange{index0, index1, index2});
2378 rewriter.create<linalg::YieldOp>(location: loc, args&: extract);
2379 });
2380 rewriter.replaceOp(op, newValues: genericOp.getResult(i: 0));
2381 return success();
2382 }
2383
2384 static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2385 Location loc,
2386 Value values,
2387 Value indices) {
2388 llvm::SmallVector<Value> results;
2389
2390 auto addDynamicDimension = [&](Value source, int64_t dim) {
2391 auto sz = tensor::getMixedSize(builder, loc, value: source, dim);
2392 if (auto dimValue = llvm::dyn_cast_if_present<Value>(Val&: sz))
2393 results.push_back(Elt: dimValue);
2394 };
2395
2396 addDynamicDimension(values, 0);
2397 addDynamicDimension(indices, 1);
2398 addDynamicDimension(values, 2);
2399 return results;
2400 }
2401};
2402
// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    // Runtime sizes for any dynamic result dimensions.
    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                    resultElementTy, dynDims)
                           .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    // Create the elementwise linalg.generic up front with an empty region;
    // the body is filled in below depending on the element types involved.
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      // i8 -> i8 table: a single lookup with the input rebased from
      // [-128, 127] to [0, 255].
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), inputValue);
        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
        index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
                                               index, offset);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      // i16 -> i32 table: look up the two neighbouring table entries and
      // linearly interpolate between them using the low 7 input bits.
      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<arith::ExtSIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(32768));
        auto seven = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(7));
        auto one = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        //   value = value + 32768
        //   index = value >> 7;
        //   fraction = value & 0b1111111
        auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
        Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
        Value fraction =
            rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table.
        //   base = (int32_t) table[index];
        //   next = (int32_t) table[index + 1];
        Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);

        index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
        next =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the input values:
        //   result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
        Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
        Value result =
            rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
2527
// Lowers tosa.rfft2d to a single linalg.generic implementing a naive DFT:
// for every output (batch, oy, ox) position it reduces over the full input
// (iy, ix) plane, accumulating valReal * cos(angle) into the real output and
// -valReal * sin(angle) into the imaginary output.
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  // Returns (ofr / 2) + 1, folded where possible. Used for the reduced output
  // width of the real FFT.
  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  // Computes the [N, H, W/2 + 1] output tensor type from the input, filling
  // `dynamicSizes` with runtime values for any dynamic dimensions.
  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  // Creates a zero-initialized tensor of the given type to accumulate into.
  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
    auto filledTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                    ValueRange{emptyTensor})
                            .result();
    return filledTensor;
  }

  // Casts an index value to the given float type via an intermediate i32
  // (or i64 when the float is wider than 32 bits).
  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = builder.create<arith::IndexCastUIOp>(
        loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
  }

  // Materializes linalg.index #index converted to the given float type.
  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  // Builds the list of affine dimension expressions (d_args...).
  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }

  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInputReal();
    auto elementType =
        dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation: parallel over
    // (batch, oy, ox), reduction over (iy, ix).
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for input and output tensors
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for angle computation
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // Calculating angle without integer parts of components as sin/cos are
      // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
      // / W);
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);
      auto realComponent =
          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto imagComponent =
          builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
2699
// Lowers tosa.fft2d to a single linalg.generic implementing a naive complex
// DFT; the `inverse` attribute flips the sign of the rotation angle.
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(fft2d->getOperandTypes(),
                      RFFT2dConverter::isRankedTensor) ||
        !llvm::all_of(fft2d->getResultTypes(),
                      RFFT2dConverter::isRankedTensor)) {
      return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
    }

    Location loc = fft2d.getLoc();
    Value input_real = fft2d.getInputReal();
    Value input_imag = fft2d.getInputImag();
    BoolAttr inverse = fft2d.getInverseAttr();

    auto real_el_ty = cast<FloatType>(
        cast<ShapedType>(input_real.getType()).getElementType());
    [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
        cast<ShapedType>(input_imag.getType()).getElementType());

    // Real and imaginary inputs must share an element type.
    assert(real_el_ty == imag_el_ty);

    // Compute the output type and set of dynamic sizes
    SmallVector<Value> dynamicSizes;

    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);

    SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto outputType = RankedTensorType::get(staticSizes, real_el_ty);

    // Iterator types for the linalg.generic implementation: parallel over
    // (batch, oy, ox), reduction over (iy, ix).
    SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation
    SmallVector<Value> genericOpInputs = {input_real, input_imag};
    SmallVector<Value> genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes)};

    // Indexing maps for input and output tensors
    auto indexingMaps = AffineMap::inferFromExprList(
        ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);

    // Constants and dimension sizes
    auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    Value constH =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
    Value constW =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value valImag = args[1];
      Value sumReal = args[2];
      Value sumImag = args[3];

      // Indices for angle computation
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
      // ox) % W ) / W);
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
      auto ixRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);

      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // The inverse transform negates the rotation angle.
      if (inverse.getValue()) {
        angle = builder.create<arith::MulFOp>(
            loc, angle,
            rewriter.create<arith::ConstantOp>(
                loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
      }

      // realComponent = val_real * cos(a) + val_imag * sin(a);
      // imagComponent = -val_real * sin(a) + val_imag * cos(a);
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);

      auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
      auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);

      auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
      auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);

      // outReal = sumReal + realComponent
      // outImag = sumImag + imagComponent (the minus sign is already folded
      // into imagComponent above)
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
2836
2837} // namespace
2838
2839void mlir::tosa::populateTosaToLinalgConversionPatterns(
2840 const TypeConverter &converter, RewritePatternSet *patterns) {
2841
  // We have multiple resize converters to handle degenerate cases.
2843 patterns->add<GenericResizeConverter>(arg: patterns->getContext(),
2844 /*benefit=*/args: 100);
2845 patterns->add<ResizeUnaryConverter>(arg: patterns->getContext(),
2846 /*benefit=*/args: 200);
2847 patterns->add<MaterializeResizeBroadcast>(arg: patterns->getContext(),
2848 /*benefit=*/args: 300);
2849
2850 patterns->add<
2851 // clang-format off
2852 PointwiseConverter<tosa::AddOp>,
2853 PointwiseConverter<tosa::SubOp>,
2854 PointwiseConverter<tosa::MulOp>,
2855 PointwiseConverter<tosa::IntDivOp>,
2856 PointwiseConverter<tosa::NegateOp>,
2857 PointwiseConverter<tosa::PowOp>,
2858 PointwiseConverter<tosa::ReciprocalOp>,
2859 PointwiseConverter<tosa::RsqrtOp>,
2860 PointwiseConverter<tosa::LogOp>,
2861 PointwiseConverter<tosa::ExpOp>,
2862 PointwiseConverter<tosa::AbsOp>,
2863 PointwiseConverter<tosa::SinOp>,
2864 PointwiseConverter<tosa::CosOp>,
2865 PointwiseConverter<tosa::TanhOp>,
2866 PointwiseConverter<tosa::ErfOp>,
2867 PointwiseConverter<tosa::BitwiseAndOp>,
2868 PointwiseConverter<tosa::BitwiseOrOp>,
2869 PointwiseConverter<tosa::BitwiseNotOp>,
2870 PointwiseConverter<tosa::BitwiseXorOp>,
2871 PointwiseConverter<tosa::LogicalAndOp>,
2872 PointwiseConverter<tosa::LogicalNotOp>,
2873 PointwiseConverter<tosa::LogicalOrOp>,
2874 PointwiseConverter<tosa::LogicalXorOp>,
2875 PointwiseConverter<tosa::CastOp>,
2876 PointwiseConverter<tosa::LogicalLeftShiftOp>,
2877 PointwiseConverter<tosa::LogicalRightShiftOp>,
2878 PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2879 PointwiseConverter<tosa::ClzOp>,
2880 PointwiseConverter<tosa::SelectOp>,
2881 PointwiseConverter<tosa::GreaterOp>,
2882 PointwiseConverter<tosa::GreaterEqualOp>,
2883 PointwiseConverter<tosa::EqualOp>,
2884 PointwiseConverter<tosa::MaximumOp>,
2885 PointwiseConverter<tosa::MinimumOp>,
2886 PointwiseConverter<tosa::CeilOp>,
2887 PointwiseConverter<tosa::FloorOp>,
2888 PointwiseConverter<tosa::ClampOp>,
2889 PointwiseConverter<tosa::SigmoidOp>
2890 >(arg: converter, args: patterns->getContext());
2891
2892 patterns->add<
2893 IdentityNConverter<tosa::IdentityOp>,
2894 ReduceConverter<tosa::ReduceAllOp>,
2895 ReduceConverter<tosa::ReduceAnyOp>,
2896 ReduceConverter<tosa::ReduceMinOp>,
2897 ReduceConverter<tosa::ReduceMaxOp>,
2898 ReduceConverter<tosa::ReduceSumOp>,
2899 ReduceConverter<tosa::ReduceProductOp>,
2900 ArgMaxConverter,
2901 GatherConverter,
2902 RescaleConverter,
2903 ReverseConverter,
2904 RFFT2dConverter,
2905 FFT2dConverter,
2906 TableConverter,
2907 TileConverter>(arg: patterns->getContext());
2908 // clang-format on
2909}
2910

source code of mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp